diff --git a/.travis.yml b/.travis.yml index 8186b425cd..64ab4bda39 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,10 +2,9 @@ language: python python: - "3.6" install: - - pip install -r requirements.txt - - pip install -U .[all] - - pip install codecov + - make setup script: + - make lint - coverage run -m pytest tests/flytekit/unit - shellcheck **/*.sh after_success: diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..ae01d4b566 --- /dev/null +++ b/Makefile @@ -0,0 +1,43 @@ +define PIP_COMPILE +pip-compile $(1) --upgrade --verbose +endef + +.SILENT: help +.PHONY: help +help: + echo Available recipes: + cat $(MAKEFILE_LIST) | grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' | awk 'BEGIN { FS = ":.*?## " } { cnt++; a[cnt] = $$1; b[cnt] = $$2; if (length($$1) > max) max = length($$1) } END { for (i = 1; i <= cnt; i++) printf " $(shell tput setaf 6)%-*s$(shell tput setaf 0) %s\n", max, a[i], b[i] }' + tput sgr0 + +.PHONY: _install-piptools +_install-piptools: + pip install -U pip-tools + +.PHONY: setup +setup: _install-piptools ## Install requirements + pip-sync requirements.txt dev-requirements.txt + +.PHONY: fmt +fmt: ## Format code with black and isort + black . + isort . + +.PHONY: lint +lint: ## Run linters + flake8 . + +.PHONY: test +test: lint ## Run tests + pytest tests/flytekit/unit + shellcheck **/*.sh + +requirements.txt: export CUSTOM_COMPILE_COMMAND := make requirements.txt +requirements.txt: requirements.in _install-piptools + $(call PIP_COMPILE,requirements.in) + +dev-requirements.txt: export CUSTOM_COMPILE_COMMAND := make dev-requirements.txt +dev-requirements.txt: dev-requirements.in requirements.txt _install-piptools + $(call PIP_COMPILE,dev-requirements.in) + +.PHONY: requirements +requirements: requirements.txt dev-requirements.txt ## Compile requirements diff --git a/README.md b/README.md index 8ff429bf0d..a3c5bf0dbc 100644 --- a/README.md +++ b/README.md @@ -81,23 +81,53 @@ Or install them with the `all` directive. `all` defaults to Spark 2.4.x currentl pip install "flytekit[all]" ``` -## Testing +## Development -Flytekit is Python 2.7+ compatible, so when feasible, it is recommended to test with both Python 2 and 3. +### Recipes -### Unit Testing +``` +$ make +Available recipes: + setup Install requirements + fmt Format code with black and isort + lint Run linters + test Run tests + requirements Compile requirements +``` + +### Setup (Do Once) -#### Setup (Do Once) ```bash virtualenv ~/.virtualenvs/flytekit source ~/.virtualenvs/flytekit/bin/activate -python -m pip install -r requirements.txt -python -m pip install -U ".[all]" +make setup +``` + +### Formatting + +We use [black](https://github.com/psf/black) and [isort](https://github.com/timothycrosley/isort) to autoformat code. Run the following command to execute the formatters: + +```bash +source ~/.virtualenvs/flytekit/bin/activate +make fmt ``` -#### Execute +### Testing + +#### Unit Testing + ```bash source ~/.virtualenvs/flytekit/bin/activate -python -m pytest tests/flytekit/unit -shellcheck **/*.sh +make test ``` + +### Updating requirements + +Update requirements in [`setup.py`](setup.py), or update requirements for development in [`dev-requirements.in`](dev-requirements.in). Then, validate, pin and freeze all requirements by running: + +```bash +source ~/.virtualenvs/flytekit/bin/activate +make requirements +``` + +This will re-create the [`requirements.txt`](requirements.txt) and [`dev-requirements.txt`](dev-requirements.txt) files which will be used for testing. 
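For reference, a minimal sketch of roughly what these two recipes expand to, assembled from the `PIP_COMPILE` macro and the `setup` recipe in the Makefile above:

```bash
# `make requirements` re-pins both files via the PIP_COMPILE macro:
pip-compile requirements.in --upgrade --verbose        # -> requirements.txt
pip-compile dev-requirements.in --upgrade --verbose    # -> dev-requirements.txt
#   (dev-requirements.in starts with `-c requirements.txt`, so its pins are
#   constrained to stay consistent with the main set)

# `make setup` then syncs the active virtualenv to exactly those pins:
pip-sync requirements.txt dev-requirements.txt
```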
You will also have to re-run `make setup` to update your local environment with the updated requirements. diff --git a/dev-requirements.in b/dev-requirements.in new file mode 100644 index 0000000000..008bb684c3 --- /dev/null +++ b/dev-requirements.in @@ -0,0 +1,10 @@ +-c requirements.txt + +black +coverage +flake8 +flake8-black +flake8-isort +isort +mock +pytest diff --git a/dev-requirements.txt b/dev-requirements.txt new file mode 100644 index 0000000000..5a0ea76ee7 --- /dev/null +++ b/dev-requirements.txt @@ -0,0 +1,34 @@ +# +# This file is autogenerated by pip-compile +# To update, run: +# +# make dev-requirements.txt +# +appdirs==1.4.4 # via -c requirements.txt, black +attrs==19.3.0 # via -c requirements.txt, black, pytest +black==19.10b0 # via -c requirements.txt, -r dev-requirements.in, flake8-black +click==7.1.2 # via -c requirements.txt, black +coverage==5.2.1 # via -r dev-requirements.in +flake8-black==0.2.1 # via -r dev-requirements.in +flake8-isort==4.0.0 # via -r dev-requirements.in +flake8==3.8.3 # via -r dev-requirements.in, flake8-black, flake8-isort +importlib-metadata==1.7.0 # via -c requirements.txt, flake8, pluggy, pytest +iniconfig==1.0.1 # via pytest +isort==5.3.2 # via -r dev-requirements.in, flake8-isort +mccabe==0.6.1 # via flake8 +mock==4.0.2 # via -r dev-requirements.in +more-itertools==8.4.0 # via pytest +packaging==20.4 # via pytest +pathspec==0.8.0 # via -c requirements.txt, black +pluggy==0.13.1 # via pytest +py==1.9.0 # via pytest +pycodestyle==2.6.0 # via flake8 +pyflakes==2.2.0 # via flake8 +pyparsing==2.4.7 # via packaging +pytest==6.0.1 # via -r dev-requirements.in +regex==2020.7.14 # via -c requirements.txt, black +six==1.15.0 # via -c requirements.txt, packaging +testfixtures==6.14.1 # via flake8-isort +toml==0.10.1 # via -c requirements.txt, black, pytest +typed-ast==1.4.1 # via -c requirements.txt, black +zipp==3.1.0 # via -c requirements.txt, importlib-metadata diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 401482eaee..b87306995a 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -1,5 +1,5 @@ from __future__ import absolute_import -import flytekit.plugins +import flytekit.plugins  # noqa: F401 -__version__ = '0.11.7' +__version__ = "0.12.0b0" diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index 4d81b3be47..7ad447de4a 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -1,19 +1,21 @@ from __future__ import absolute_import +import datetime as _datetime import importlib as _importlib import os as _os +import random as _random import click as _click -import datetime as _datetime -import random as _random from flyteidl.core import literals_pb2 as _literals_pb2 from flytekit.common import utils as _utils -from flytekit.common.exceptions import scopes as _scopes, system as _system_exceptions -from flytekit.configuration import internal as _internal_config, TemporaryConfiguration as _TemporaryConfiguration +from flytekit.common.exceptions import scopes as _scopes +from flytekit.common.exceptions import system as _system_exceptions +from flytekit.configuration import TemporaryConfiguration as _TemporaryConfiguration +from flytekit.configuration import internal as _internal_config from flytekit.engines import loader as _engine_loader -from flytekit.interfaces.data import data_proxy as _data_proxy from flytekit.interfaces import random as _flyte_random +from flytekit.interfaces.data import data_proxy as _data_proxy from flytekit.models import literals as _literal_models @@ -26,14
+28,14 @@ def _compute_array_job_index(): :rtype: int """ offset = 0 - if _os.environ.get('BATCH_JOB_ARRAY_INDEX_OFFSET'): - offset = int(_os.environ.get('BATCH_JOB_ARRAY_INDEX_OFFSET')) - return offset + int(_os.environ.get(_os.environ.get('BATCH_JOB_ARRAY_INDEX_VAR_NAME'))) + if _os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET"): + offset = int(_os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET")) + return offset + int(_os.environ.get(_os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"))) def _map_job_index_to_child_index(local_input_dir, datadir, index): - local_lookup_file = local_input_dir.get_named_tempfile('indexlookup.pb') - idx_lookup_file = _os.path.join(datadir, 'indexlookup.pb') + local_lookup_file = local_input_dir.get_named_tempfile("indexlookup.pb") + idx_lookup_file = _os.path.join(datadir, "indexlookup.pb") # if the indexlookup.pb does not exist, then just return the index if not _data_proxy.Data.data_exists(idx_lookup_file): @@ -44,47 +46,44 @@ def _map_job_index_to_child_index(local_input_dir, datadir, index): if len(mapping_proto.literals) < index: raise _system_exceptions.FlyteSystemAssertion( "dynamic task index lookup array size: {} is smaller than lookup index {}".format( - len(mapping_proto.literals), index)) + len(mapping_proto.literals), index + ) + ) return mapping_proto.literals[index].scalar.primitive.integer @_scopes.system_entry_point def _execute_task(task_module, task_name, inputs, output_prefix, test): with _TemporaryConfiguration(_internal_config.CONFIGURATION_PATH.get()): - with _utils.AutoDeletingTempDir('input_dir') as input_dir: + with _utils.AutoDeletingTempDir("input_dir") as input_dir: # Load user code task_module = _importlib.import_module(task_module) task_def = getattr(task_module, task_name) if not test: - local_inputs_file = input_dir.get_named_tempfile('inputs.pb') + local_inputs_file = input_dir.get_named_tempfile("inputs.pb") # Handle inputs/outputs for array job. - if _os.environ.get('BATCH_JOB_ARRAY_INDEX_VAR_NAME'): + if _os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME"): job_index = _compute_array_job_index() # TODO: Perhaps remove. This is a workaround to an issue we perceived with limited entropy in # TODO: AWS batch array jobs. _flyte_random.seed_flyte_random( - "{} {} {}".format( - _random.random(), - _datetime.datetime.utcnow(), - job_index - ) + "{} {} {}".format(_random.random(), _datetime.datetime.utcnow(), job_index) ) # If an ArrayTask is discoverable, the original job index may be different than the one specified in # the environment variable. Look up the correct input/outputs in the index lookup mapping file. 
job_index = _map_job_index_to_child_index(input_dir, inputs, job_index) - inputs = _os.path.join(inputs, str(job_index), 'inputs.pb') + inputs = _os.path.join(inputs, str(job_index), "inputs.pb") output_prefix = _os.path.join(output_prefix, str(job_index)) _data_proxy.Data.get_data(inputs, local_inputs_file) input_proto = _utils.load_proto_from_file(_literals_pb2.LiteralMap, local_inputs_file) _engine_loader.get_engine().get_task(task_def).execute( - _literal_models.LiteralMap.from_flyte_idl(input_proto), - context={'output_prefix': output_prefix} + _literal_models.LiteralMap.from_flyte_idl(input_proto), context={"output_prefix": output_prefix}, ) @@ -93,16 +92,16 @@ def _pass_through(): pass -@_pass_through.command('pyflyte-execute') -@_click.option('--task-module', required=True) -@_click.option('--task-name', required=True) -@_click.option('--inputs', required=True) -@_click.option('--output-prefix', required=True) -@_click.option('--test', is_flag=True) +@_pass_through.command("pyflyte-execute") +@_click.option("--task-module", required=True) +@_click.option("--task-name", required=True) +@_click.option("--inputs", required=True) +@_click.option("--output-prefix", required=True) +@_click.option("--test", is_flag=True) def execute_task_cmd(task_module, task_name, inputs, output_prefix, test): _click.echo(_utils.get_version_message()) _execute_task(task_module, task_name, inputs, output_prefix, test) -if __name__ == '__main__': +if __name__ == "__main__": _pass_through() diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index 329be5991e..a89b1ae374 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -1,23 +1,30 @@ from __future__ import absolute_import import six as _six - -from flyteidl.admin import task_pb2 as _task_pb2, common_pb2 as _common_pb2, workflow_pb2 as _workflow_pb2, \ - launch_plan_pb2 as _launch_plan_pb2, execution_pb2 as _execution_pb2, node_execution_pb2 as _node_execution_pb2, \ - task_execution_pb2 as _task_execution_pb2, project_pb2 as _project_pb2, project_domain_attributes_pb2 as \ - _project_domain_attributes_pb2, workflow_attributes_pb2 as _workflow_attributes_pb2 -from flyteidl.core import identifier_pb2 as _identifier_pb2 +from flyteidl.admin import common_pb2 as _common_pb2 +from flyteidl.admin import execution_pb2 as _execution_pb2 +from flyteidl.admin import launch_plan_pb2 as _launch_plan_pb2 +from flyteidl.admin import node_execution_pb2 as _node_execution_pb2 +from flyteidl.admin import project_domain_attributes_pb2 as _project_domain_attributes_pb2 +from flyteidl.admin import project_pb2 as _project_pb2 +from flyteidl.admin import task_execution_pb2 as _task_execution_pb2 +from flyteidl.admin import task_pb2 as _task_pb2 +from flyteidl.admin import workflow_attributes_pb2 as _workflow_attributes_pb2 +from flyteidl.admin import workflow_pb2 as _workflow_pb2 from flytekit.clients.raw import RawSynchronousFlyteClient as _RawSynchronousFlyteClient -from flytekit.models import filters as _filters, common as _common, launch_plan as _launch_plan, task as _task, \ - execution as _execution, node_execution as _node_execution +from flytekit.models import common as _common +from flytekit.models import execution as _execution +from flytekit.models import filters as _filters +from flytekit.models import launch_plan as _launch_plan +from flytekit.models import node_execution as _node_execution +from flytekit.models import task as _task +from flytekit.models.admin import task_execution as _task_execution +from 
flytekit.models.admin import workflow as _workflow from flytekit.models.core import identifier as _identifier -from flytekit.models.admin import workflow as _workflow, task_execution as _task_execution -from flytekit.common.exceptions.user import FlyteAssertion as _FlyteAssertion class SynchronousFlyteClient(_RawSynchronousFlyteClient): - @property def raw(self): """ @@ -32,11 +39,7 @@ def raw(self): # #################################################################################################################### - def create_task( - self, - task_identifer, - task_spec - ): + def create_task(self, task_identifer, task_spec): """ This will create a task definition in the Admin database. Once successful, the task object can be retrieved via the client or viewed via the UI or command-line interfaces. @@ -56,20 +59,10 @@ def create_task( :raises grpc.RpcError: """ super(SynchronousFlyteClient, self).create_task( - _task_pb2.TaskCreateRequest( - id=task_identifer.to_flyte_idl(), - spec=task_spec.to_flyte_idl() - ) + _task_pb2.TaskCreateRequest(id=task_identifer.to_flyte_idl(), spec=task_spec.to_flyte_idl()) ) - def list_task_ids_paginated( - self, - project, - domain, - limit=100, - token=None, - sort_by=None - ): + def list_task_ids_paginated(self, project, domain, limit=100, token=None, sort_by=None): """ This returns a page of identifiers for the tasks for a given project and domain. Filters can also be specified. @@ -102,22 +95,15 @@ def list_task_ids_paginated( domain=domain, limit=limit, token=token, - sort_by=None if sort_by is None else sort_by.to_flyte_idl() + sort_by=None if sort_by is None else sort_by.to_flyte_idl(), ) ) - return [ - _common.NamedEntityIdentifier.from_flyte_idl(identifier_pb) - for identifier_pb in identifier_list.entities - ], _six.text_type(identifier_list.token) - - def list_tasks_paginated( - self, - identifier, - limit=100, - token=None, - filters=None, - sort_by=None - ): + return ( + [_common.NamedEntityIdentifier.from_flyte_idl(identifier_pb) for identifier_pb in identifier_list.entities], + _six.text_type(identifier_list.token), + ) + + def list_tasks_paginated(self, identifier, limit=100, token=None, filters=None, sort_by=None): """ This returns a page of task metadata for tasks in a given project and domain. Optionally, specifying a name will limit the results to only tasks with that name in the given project and domain. 
@@ -151,13 +137,16 @@ def list_tasks_paginated( limit=limit, token=token, filters=_filters.FilterList(filters or []).to_flyte_idl(), - sort_by=None if sort_by is None else sort_by.to_flyte_idl() + sort_by=None if sort_by is None else sort_by.to_flyte_idl(), ) ) # TODO: tmp workaround for pb in task_list.tasks: pb.id.resource_type = _identifier.ResourceType.TASK - return [_task.Task.from_flyte_idl(task_pb2) for task_pb2 in task_list.tasks], _six.text_type(task_list.token) + return ( + [_task.Task.from_flyte_idl(task_pb2) for task_pb2 in task_list.tasks], + _six.text_type(task_list.token), + ) def get_task(self, id): """ @@ -168,11 +157,7 @@ def get_task(self, id): :rtype: flytekit.models.task.Task """ return _task.Task.from_flyte_idl( - super(SynchronousFlyteClient, self).get_task( - _common_pb2.ObjectGetRequest( - id=id.to_flyte_idl() - ) - ) + super(SynchronousFlyteClient, self).get_task(_common_pb2.ObjectGetRequest(id=id.to_flyte_idl())) ) #################################################################################################################### @@ -181,11 +166,7 @@ def get_task(self, id): # #################################################################################################################### - def create_workflow( - self, - workflow_identifier, - workflow_spec - ): + def create_workflow(self, workflow_identifier, workflow_spec): """ This will create a workflow definition in the Admin database. Once successful, the workflow object can be retrieved via the client or viewed via the UI or command-line interfaces. @@ -206,19 +187,11 @@ def create_workflow( """ super(SynchronousFlyteClient, self).create_workflow( _workflow_pb2.WorkflowCreateRequest( - id=workflow_identifier.to_flyte_idl(), - spec=workflow_spec.to_flyte_idl() + id=workflow_identifier.to_flyte_idl(), spec=workflow_spec.to_flyte_idl() ) ) - def list_workflow_ids_paginated( - self, - project, - domain, - limit=100, - token=None, - sort_by=None - ): + def list_workflow_ids_paginated(self, project, domain, limit=100, token=None, sort_by=None): """ This returns a page of identifiers for the workflows for a given project and domain. Filters can also be specified. @@ -251,22 +224,15 @@ def list_workflow_ids_paginated( domain=domain, limit=limit, token=token, - sort_by=None if sort_by is None else sort_by.to_flyte_idl() + sort_by=None if sort_by is None else sort_by.to_flyte_idl(), ) ) - return [ - _common.NamedEntityIdentifier.from_flyte_idl(identifier_pb) - for identifier_pb in identifier_list.entities - ], _six.text_type(identifier_list.token) - - def list_workflows_paginated( - self, - identifier, - limit=100, - token=None, - filters=None, - sort_by=None - ): + return ( + [_common.NamedEntityIdentifier.from_flyte_idl(identifier_pb) for identifier_pb in identifier_list.entities], + _six.text_type(identifier_list.token), + ) + + def list_workflows_paginated(self, identifier, limit=100, token=None, filters=None, sort_by=None): """ This returns a page of workflow meta-information for workflows in a given project and domain. Optionally, specifying a name will limit the results to only workflows with that name in the given project and domain. 
@@ -300,14 +266,16 @@ def list_workflows_paginated( limit=limit, token=token, filters=_filters.FilterList(filters or []).to_flyte_idl(), - sort_by=None if sort_by is None else sort_by.to_flyte_idl() + sort_by=None if sort_by is None else sort_by.to_flyte_idl(), ) ) # TODO: tmp workaround for pb in wf_list.workflows: pb.id.resource_type = _identifier.ResourceType.WORKFLOW - return [_workflow.Workflow.from_flyte_idl(wf_pb2) for wf_pb2 in wf_list.workflows], \ - _six.text_type(wf_list.token) + return ( + [_workflow.Workflow.from_flyte_idl(wf_pb2) for wf_pb2 in wf_list.workflows], + _six.text_type(wf_list.token), + ) def get_workflow(self, id): """ @@ -318,11 +286,7 @@ def get_workflow(self, id): :rtype: flytekit.models.admin.workflow.Workflow """ return _workflow.Workflow.from_flyte_idl( - super(SynchronousFlyteClient, self).get_workflow( - _common_pb2.ObjectGetRequest( - id=id.to_flyte_idl() - ) - ) + super(SynchronousFlyteClient, self).get_workflow(_common_pb2.ObjectGetRequest(id=id.to_flyte_idl())) ) #################################################################################################################### @@ -331,11 +295,7 @@ def get_workflow(self, id): # #################################################################################################################### - def create_launch_plan( - self, - launch_plan_identifer, - launch_plan_spec - ): + def create_launch_plan(self, launch_plan_identifer, launch_plan_spec): """ This will create a launch plan definition in the Admin database. Once successful, the launch plan object can be retrieved via the client or viewed via the UI or command-line interfaces. @@ -356,8 +316,7 @@ def create_launch_plan( """ super(SynchronousFlyteClient, self).create_launch_plan( _launch_plan_pb2.LaunchPlanCreateRequest( - id=launch_plan_identifer.to_flyte_idl(), - spec=launch_plan_spec.to_flyte_idl() + id=launch_plan_identifer.to_flyte_idl(), spec=launch_plan_spec.to_flyte_idl(), ) ) @@ -369,11 +328,7 @@ def get_launch_plan(self, id): :rtype: flytekit.models.launch_plan.LaunchPlan """ return _launch_plan.LaunchPlan.from_flyte_idl( - super(SynchronousFlyteClient, self).get_launch_plan( - _common_pb2.ObjectGetRequest( - id=id.to_flyte_idl() - ) - ) + super(SynchronousFlyteClient, self).get_launch_plan(_common_pb2.ObjectGetRequest(id=id.to_flyte_idl())) ) def get_active_launch_plan(self, identifier): @@ -386,20 +341,11 @@ def get_active_launch_plan(self, identifier): """ return _launch_plan.LaunchPlan.from_flyte_idl( super(SynchronousFlyteClient, self).get_active_launch_plan( - _launch_plan_pb2.ActiveLaunchPlanRequest( - id=identifier.to_flyte_idl() - ) + _launch_plan_pb2.ActiveLaunchPlanRequest(id=identifier.to_flyte_idl()) ) ) - def list_launch_plan_ids_paginated( - self, - project, - domain, - limit=100, - token=None, - sort_by=None - ): + def list_launch_plan_ids_paginated(self, project, domain, limit=100, token=None, sort_by=None): """ This returns a page of identifiers for the launch plans for a given project and domain. Filters can also be specified. 
@@ -432,22 +378,15 @@ def list_launch_plan_ids_paginated( domain=domain, limit=limit, token=token, - sort_by=None if sort_by is None else sort_by.to_flyte_idl() + sort_by=None if sort_by is None else sort_by.to_flyte_idl(), ) ) - return [ - _common.NamedEntityIdentifier.from_flyte_idl(identifier_pb) - for identifier_pb in identifier_list.entities - ], _six.text_type(identifier_list.token) - - def list_launch_plans_paginated( - self, - identifier, - limit=100, - token=None, - filters=None, - sort_by=None - ): + return ( + [_common.NamedEntityIdentifier.from_flyte_idl(identifier_pb) for identifier_pb in identifier_list.entities], + _six.text_type(identifier_list.token), + ) + + def list_launch_plans_paginated(self, identifier, limit=100, token=None, filters=None, sort_by=None): """ This returns a page of launch plan meta-information for launch plans in a given project and domain. Optionally, specifying a name will limit the results to only workflows with that name in the given project and domain. @@ -481,23 +420,18 @@ def list_launch_plans_paginated( limit=limit, token=token, filters=_filters.FilterList(filters or []).to_flyte_idl(), - sort_by=None if sort_by is None else sort_by.to_flyte_idl() + sort_by=None if sort_by is None else sort_by.to_flyte_idl(), ) ) # TODO: tmp workaround for pb in lp_list.launch_plans: pb.id.resource_type = _identifier.ResourceType.LAUNCH_PLAN - return [_launch_plan.LaunchPlan.from_flyte_idl(pb) for pb in lp_list.launch_plans], \ - _six.text_type(lp_list.token) - - def list_active_launch_plans_paginated( - self, - project, - domain, - limit=100, - token=None, - sort_by=None - ): + return ( + [_launch_plan.LaunchPlan.from_flyte_idl(pb) for pb in lp_list.launch_plans], + _six.text_type(lp_list.token), + ) + + def list_active_launch_plans_paginated(self, project, domain, limit=100, token=None, sort_by=None): """ This returns a page of currently active launch plan meta-information for launch plans in a given project and domain. 
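An aside on launch-plan state: the `update_launch_plan` method further below takes an enum value from `flytekit.models.launch_plan.LaunchPlanState`, per its docstring. A hedged usage sketch — `client` and `lp_id` are placeholders (a constructed `SynchronousFlyteClient` and an existing launch plan identifier), not defined in this diff:

```python
from flytekit.models.launch_plan import LaunchPlanState

# Placeholders: `client` is a SynchronousFlyteClient, `lp_id` identifies an
# existing launch plan. Activating a launch plan turns its schedule on;
# LaunchPlanState.INACTIVE would turn it back off.
client.update_launch_plan(id=lp_id, state=LaunchPlanState.ACTIVE)
```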
@@ -530,14 +464,16 @@ def list_active_launch_plans_paginated( domain=domain, limit=limit, token=token, - sort_by=None if sort_by is None else sort_by.to_flyte_idl() + sort_by=None if sort_by is None else sort_by.to_flyte_idl(), ) ) # TODO: tmp workaround for pb in lp_list.launch_plans: pb.id.resource_type = _identifier.ResourceType.LAUNCH_PLAN - return [_launch_plan.LaunchPlan.from_flyte_idl(pb) for pb in lp_list.launch_plans], \ - _six.text_type(lp_list.token) + return ( + [_launch_plan.LaunchPlan.from_flyte_idl(pb) for pb in lp_list.launch_plans], + _six.text_type(lp_list.token), + ) def update_launch_plan(self, id, state): """ @@ -550,10 +486,7 @@ def update_launch_plan(self, id, state): :param int state: Enum value from flytekit.models.launch_plan.LaunchPlanState """ super(SynchronousFlyteClient, self).update_launch_plan( - _launch_plan_pb2.LaunchPlanUpdateRequest( - id=id.to_flyte_idl(), - state=state - ) + _launch_plan_pb2.LaunchPlanUpdateRequest(id=id.to_flyte_idl(), state=state) ) #################################################################################################################### @@ -573,9 +506,7 @@ def update_named_entity(self, resource_type, id, metadata): """ super(SynchronousFlyteClient, self).update_named_entity( _common_pb2.NamedEntityUpdateRequest( - resource_type=resource_type, - id=id.to_flyte_idl(), - metadata=metadata.to_flyte_idl(), + resource_type=resource_type, id=id.to_flyte_idl(), metadata=metadata.to_flyte_idl(), ) ) @@ -597,7 +528,8 @@ def create_execution(self, project, domain, name, execution_spec, inputs): :rtype: flytekit.models.core.identifier.WorkflowExecutionIdentifier """ return _identifier.WorkflowExecutionIdentifier.from_flyte_idl( - super(SynchronousFlyteClient, self).create_execution( + super(SynchronousFlyteClient, self) + .create_execution( _execution_pb2.ExecutionCreateRequest( project=project, domain=domain, @@ -605,7 +537,8 @@ def create_execution(self, project, domain, name, execution_spec, inputs): spec=execution_spec.to_flyte_idl(), inputs=inputs.to_flyte_idl(), ) - ).id + ) + .id ) def get_execution(self, id): @@ -615,9 +548,7 @@ def get_execution(self, id): """ return _execution.Execution.from_flyte_idl( super(SynchronousFlyteClient, self).get_execution( - _execution_pb2.WorkflowExecutionGetRequest( - id=id.to_flyte_idl() - ) + _execution_pb2.WorkflowExecutionGetRequest(id=id.to_flyte_idl()) ) ) @@ -630,21 +561,11 @@ def get_execution_data(self, id): """ return _execution.WorkflowExecutionGetDataResponse.from_flyte_idl( super(SynchronousFlyteClient, self).get_execution_data( - _execution_pb2.WorkflowExecutionGetDataRequest( - id=id.to_flyte_idl() - ) + _execution_pb2.WorkflowExecutionGetDataRequest(id=id.to_flyte_idl()) ) ) - def list_executions_paginated( - self, - project, - domain, - limit=100, - token=None, - filters=None, - sort_by=None - ): + def list_executions_paginated(self, project, domain, limit=100, token=None, filters=None, sort_by=None): """ This returns a page of executions in a given project and domain. 
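Every `list_*_paginated` method here returns a `(results, token)` pair, where an empty token marks the last page — the same contract the generators in `flytekit/clients/helpers.py` below rely on. A minimal paging sketch, assuming a constructed `SynchronousFlyteClient` named `client`; the project/domain strings are placeholders:

```python
# Page through every execution in a (placeholder) project and domain.
token = ""
while True:
    executions, token = client.list_executions_paginated(
        project="myproject",  # placeholder, not from this diff
        domain="development",  # placeholder, not from this diff
        limit=100,
        token=token,
    )
    for ex in executions:
        print(ex.id.name)  # execution ids carry project, domain, and name
    if not token:  # empty token: no more pages
        break
```

The same loop shape works for the task, workflow, and launch-plan listings, since they all share the `(results, token)` return contract.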
@@ -674,17 +595,17 @@ def list_executions_paginated( """ exec_list = super(SynchronousFlyteClient, self).list_executions_paginated( resource_list_request=_common_pb2.ResourceListRequest( - id=_common_pb2.NamedEntityIdentifier( - project=project, - domain=domain - ), + id=_common_pb2.NamedEntityIdentifier(project=project, domain=domain), limit=limit, token=token, filters=_filters.FilterList(filters or []).to_flyte_idl(), - sort_by=None if sort_by is None else sort_by.to_flyte_idl() + sort_by=None if sort_by is None else sort_by.to_flyte_idl(), ) ) - return [_execution.Execution.from_flyte_idl(pb) for pb in exec_list.executions], _six.text_type(exec_list.token) + return ( + [_execution.Execution.from_flyte_idl(pb) for pb in exec_list.executions], + _six.text_type(exec_list.token), + ) def terminate_execution(self, id, cause): """ @@ -692,10 +613,7 @@ def terminate_execution(self, id, cause): :param Text cause: """ super(SynchronousFlyteClient, self).terminate_execution( - _execution_pb2.ExecutionTerminateRequest( - id=id.to_flyte_idl(), - cause=cause - ) + _execution_pb2.ExecutionTerminateRequest(id=id.to_flyte_idl(), cause=cause) ) def relaunch_execution(self, id, name=None): @@ -707,12 +625,9 @@ def relaunch_execution(self, id, name=None): :rtype: flytekit.models.core.identifier.WorkflowExecutionIdentifier """ return _identifier.WorkflowExecutionIdentifier.from_flyte_idl( - super(SynchronousFlyteClient, self).relaunch_execution( - _execution_pb2.ExecutionRelaunchRequest( - id=id.to_flyte_idl(), - name=name - ) - ).id + super(SynchronousFlyteClient, self) + .relaunch_execution(_execution_pb2.ExecutionRelaunchRequest(id=id.to_flyte_idl(), name=name)) + .id ) #################################################################################################################### @@ -728,9 +643,7 @@ def get_node_execution(self, node_execution_identifier): """ return _node_execution.NodeExecution.from_flyte_idl( super(SynchronousFlyteClient, self).get_node_execution( - _node_execution_pb2.NodeExecutionGetRequest( - id=node_execution_identifier.to_flyte_idl() - ) + _node_execution_pb2.NodeExecutionGetRequest(id=node_execution_identifier.to_flyte_idl()) ) ) @@ -743,19 +656,12 @@ def get_node_execution_data(self, node_execution_identifier): """ return _execution.NodeExecutionGetDataResponse.from_flyte_idl( super(SynchronousFlyteClient, self).get_node_execution_data( - _node_execution_pb2.NodeExecutionGetDataRequest( - id=node_execution_identifier.to_flyte_idl() - ) + _node_execution_pb2.NodeExecutionGetDataRequest(id=node_execution_identifier.to_flyte_idl()) ) ) def list_node_executions( - self, - workflow_execution_identifier, - limit=100, - token=None, - filters=None, - sort_by=None + self, workflow_execution_identifier, limit=100, token=None, filters=None, sort_by=None, ): """ TODO: Comment @@ -774,19 +680,16 @@ def list_node_executions( limit=limit, token=token, filters=_filters.FilterList(filters or []).to_flyte_idl(), - sort_by=None if sort_by is None else sort_by.to_flyte_idl() + sort_by=None if sort_by is None else sort_by.to_flyte_idl(), ) ) - return [_node_execution.NodeExecution.from_flyte_idl(e) for e in exec_list.node_executions], \ - _six.text_type(exec_list.token) + return ( + [_node_execution.NodeExecution.from_flyte_idl(e) for e in exec_list.node_executions], + _six.text_type(exec_list.token), + ) def list_node_executions_for_task_paginated( - self, - task_execution_identifier, - limit=100, - token=None, - filters=None, - sort_by=None + self, task_execution_identifier, limit=100, 
token=None, filters=None, sort_by=None, ): """ This returns nodes spawned by a specific task execution. This is generally from things like dynamic tasks. @@ -805,11 +708,13 @@ def list_node_executions_for_task_paginated( limit=limit, token=token, filters=_filters.FilterList(filters or []).to_flyte_idl(), - sort_by=None if sort_by is None else sort_by.to_flyte_idl() + sort_by=None if sort_by is None else sort_by.to_flyte_idl(), ) ) - return [_node_execution.NodeExecution.from_flyte_idl(e) for e in exec_list.node_executions], \ - _six.text_type(exec_list.token) + return ( + [_node_execution.NodeExecution.from_flyte_idl(e) for e in exec_list.node_executions], + _six.text_type(exec_list.token), + ) #################################################################################################################### # @@ -824,9 +729,7 @@ def get_task_execution(self, id): """ return _task_execution.TaskExecution.from_flyte_idl( super(SynchronousFlyteClient, self).get_task_execution( - _task_execution_pb2.TaskExecutionGetRequest( - id=id.to_flyte_idl() - ) + _task_execution_pb2.TaskExecutionGetRequest(id=id.to_flyte_idl()) ) ) @@ -839,14 +742,13 @@ def get_task_execution_data(self, task_execution_identifier): """ return _execution.TaskExecutionGetDataResponse.from_flyte_idl( super(SynchronousFlyteClient, self).get_task_execution_data( - _task_execution_pb2.TaskExecutionGetDataRequest( - id=task_execution_identifier.to_flyte_idl() - ) + _task_execution_pb2.TaskExecutionGetDataRequest(id=task_execution_identifier.to_flyte_idl()) ) ) - def list_task_executions_paginated(self, node_execution_identifier, limit=100, token=None, filters=None, - sort_by=None): + def list_task_executions_paginated( + self, node_execution_identifier, limit=100, token=None, filters=None, sort_by=None, + ): """ :param flytekit.models.core.identifier.NodeExecutionIdentifier node_execution_identifier: :param int limit: @@ -863,11 +765,13 @@ def list_task_executions_paginated(self, node_execution_identifier, limit=100, t limit=limit, token=token, filters=_filters.FilterList(filters or []).to_flyte_idl(), - sort_by=None if sort_by is None else sort_by.to_flyte_idl() + sort_by=None if sort_by is None else sort_by.to_flyte_idl(), ) ) - return [_task_execution.TaskExecution.from_flyte_idl(e) for e in exec_list.task_executions], \ - _six.text_type(exec_list.token) + return ( + [_task_execution.TaskExecution.from_flyte_idl(e) for e in exec_list.task_executions], + _six.text_type(exec_list.token), + ) #################################################################################################################### # @@ -882,9 +786,7 @@ def register_project(self, project): :rtype: flyteidl.admin.project_pb2.ProjectRegisterResponse """ super(SynchronousFlyteClient, self).register_project( - _project_pb2.ProjectRegisterRequest( - project=project.to_flyte_idl(), - ) + _project_pb2.ProjectRegisterRequest(project=project.to_flyte_idl(),) ) #################################################################################################################### @@ -904,9 +806,7 @@ def update_project_domain_attributes(self, project, domain, matching_attributes) super(SynchronousFlyteClient, self).update_project_domain_attributes( _project_domain_attributes_pb2.ProjectDomainAttributesUpdateRequest( attributes=_project_domain_attributes_pb2.ProjectDomainAttributes( - project=project, - domain=domain, - matching_attributes=matching_attributes.to_flyte_idl(), + project=project, domain=domain, matching_attributes=matching_attributes.to_flyte_idl(), ) ) 
) diff --git a/flytekit/clients/helpers.py b/flytekit/clients/helpers.py index 10640b6d74..75b2232636 100644 --- a/flytekit/clients/helpers.py +++ b/flytekit/clients/helpers.py @@ -1,14 +1,5 @@ - -from flytekit.clis.auth import credentials as _credentials_access - - - def iterate_node_executions( - client, - workflow_execution_identifier=None, - task_execution_identifier=None, - limit=None, - filters=None + client, workflow_execution_identifier=None, task_execution_identifier=None, limit=None, filters=None, ): """ This returns a generator for node executions. @@ -30,14 +21,11 @@ def iterate_node_executions( workflow_execution_identifier=workflow_execution_identifier, limit=num_to_fetch, token=token, - filters=filters + filters=filters, ) else: node_execs, next_token = client.list_node_executions_for_task_paginated( - task_execution_identifier=task_execution_identifier, - limit=num_to_fetch, - token=token, - filters=filters + task_execution_identifier=task_execution_identifier, limit=num_to_fetch, token=token, filters=filters, ) for n in node_execs: counter += 1 @@ -65,10 +53,7 @@ def iterate_task_executions(client, node_execution_identifier, limit=None, filte counter = 0 while True: task_execs, next_token = client.list_task_executions_paginated( - node_execution_identifier=node_execution_identifier, - limit=num_to_fetch, - token=token, - filters=filters + node_execution_identifier=node_execution_identifier, limit=num_to_fetch, token=token, filters=filters, ) for t in task_execs: counter += 1 @@ -78,4 +63,3 @@ def iterate_task_executions(client, node_execution_identifier, limit=None, filte if not next_token: break token = next_token - diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 6a596769a0..d1c89c820b 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -1,21 +1,23 @@ from __future__ import absolute_import -from grpc import insecure_channel as _insecure_channel, secure_channel as _secure_channel, RpcError as _RpcError, \ - StatusCode as _GrpcStatusCode, ssl_channel_credentials as _ssl_channel_credentials -from google.protobuf.json_format import MessageToJson as _MessageToJson -from flyteidl.service import admin_pb2_grpc as _admin_service -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.configuration.platform import AUTH as _AUTH -from flytekit.configuration.creds import ( - CLIENT_ID as _CLIENT_ID, - CLIENT_CREDENTIALS_SCOPE as _SCOPE, -) -from flytekit.clis.sdk_in_container import basic_auth as _basic_auth import logging as _logging + import six as _six -from flytekit.configuration import creds as _creds_config, platform as _platform_config +from flyteidl.service import admin_pb2_grpc as _admin_service +from google.protobuf.json_format import MessageToJson as _MessageToJson +from grpc import RpcError as _RpcError +from grpc import StatusCode as _GrpcStatusCode +from grpc import insecure_channel as _insecure_channel +from grpc import secure_channel as _secure_channel +from grpc import ssl_channel_credentials as _ssl_channel_credentials from flytekit.clis.auth import credentials as _credentials_access +from flytekit.clis.sdk_in_container import basic_auth as _basic_auth +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.configuration import creds as _creds_config +from flytekit.configuration.creds import CLIENT_CREDENTIALS_SCOPE as _SCOPE +from flytekit.configuration.creds import CLIENT_ID as _CLIENT_ID +from flytekit.configuration.platform import AUTH as _AUTH def 
_refresh_credentials_standard(flyte_client): @@ -43,10 +45,10 @@ def _refresh_credentials_basic(flyte_client): auth_endpoints = _credentials_access.get_authorization_endpoints(flyte_client.url) token_endpoint = auth_endpoints.token_endpoint client_secret = _basic_auth.get_secret() - _logging.debug('Basic authorization flow with client id {} scope {}'.format(_CLIENT_ID.get(), _SCOPE.get())) + _logging.debug("Basic authorization flow with client id {} scope {}".format(_CLIENT_ID.get(), _SCOPE.get())) authorization_header = _basic_auth.get_basic_authorization_header(_CLIENT_ID.get(), client_secret) token, expires_in = _basic_auth.get_token(token_endpoint, authorization_header, _SCOPE.get()) - _logging.info('Retrieved new token, expires in {}'.format(expires_in)) + _logging.info("Retrieved new token, expires in {}".format(expires_in)) flyte_client.set_access_token(token) @@ -61,7 +63,8 @@ def _get_refresh_handler(auth_mode): return _refresh_credentials_basic else: raise ValueError( - "Invalid auth mode [{}] specified. Please update the creds config to use a valid value".format(auth_mode)) + "Invalid auth mode [{}] specified. Please update the creds config to use a valid value".format(auth_mode) + ) def _handle_rpc_error(fn): @@ -91,6 +94,7 @@ def handler(*args, **kwargs): raise _user_exceptions.FlyteEntityAlreadyExistsException(_six.text_type(e)) else: raise + return handler @@ -136,9 +140,7 @@ def __init__(self, url, insecure=False, credentials=None, options=None): self._channel = _insecure_channel(url, options=list((options or {}).items())) else: self._channel = _secure_channel( - url, - credentials or _ssl_channel_credentials(), - options=list((options or {}).items()) + url, credentials or _ssl_channel_credentials(), options=list((options or {}).items()), ) self._stub = _admin_service.AdminServiceStub(self._channel) self._metadata = None @@ -152,7 +154,7 @@ def url(self) -> str: def set_access_token(self, access_token): # Always set the header to lower-case regardless of what the config is. The grpc libraries that Admin uses # to parse the metadata don't change the metadata, but they do automatically lower the key you're looking for. 
- self._metadata = [(_creds_config.AUTHORIZATION_METADATA_KEY.get().lower(), "Bearer {}".format(access_token))] + self._metadata = [(_creds_config.AUTHORIZATION_METADATA_KEY.get().lower(), "Bearer {}".format(access_token),)] def force_auth_flow(self): refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get()) @@ -607,8 +609,9 @@ def update_project_domain_attributes(self, project_domain_attributes_update_requ :param flyteidl.admin..ProjectDomainAttributesUpdateRequest project_domain_attributes_update_request: :rtype: flyteidl.admin..ProjectDomainAttributesUpdateResponse """ - return self._stub.UpdateProjectDomainAttributes(project_domain_attributes_update_request, - metadata=self._metadata) + return self._stub.UpdateProjectDomainAttributes( + project_domain_attributes_update_request, metadata=self._metadata + ) @_handle_rpc_error def update_workflow_attributes(self, workflow_attributes_update_request): diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index 33f201bac1..2afc4644e7 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -1,12 +1,13 @@ import base64 as _base64 import hashlib as _hashlib -import keyring as _keyring import os as _os import re as _re -import requests as _requests import webbrowser as _webbrowser +from multiprocessing import Process as _Process +from multiprocessing import Queue as _Queue -from multiprocessing import Process as _Process, Queue as _Queue +import keyring as _keyring +import requests as _requests try: # Python 3.5+ from http import HTTPStatus as _StatusCodes @@ -24,12 +25,13 @@ import urllib.parse as _urlparse from urllib.parse import urlencode as _urlencode except ImportError: # Python 2 - import urlparse as _urlparse from urllib import urlencode as _urlencode + import urlparse as _urlparse + _code_verifier_length = 64 _random_seed_length = 40 -_utf_8 = 'utf-8' +_utf_8 = "utf-8" # Identifies the service used for storing passwords in keyring @@ -48,7 +50,7 @@ def _generate_code_verifier(): """ code_verifier = _base64.urlsafe_b64encode(_os.urandom(_code_verifier_length)).decode(_utf_8) # Eliminate invalid characters. - code_verifier = _re.sub(r'[^a-zA-Z0-9_\-.~]+', '', code_verifier) + code_verifier = _re.sub(r"[^a-zA-Z0-9_\-.~]+", "", code_verifier) if len(code_verifier) < 43: raise ValueError("Verifier too short. number of bytes must be > 30.") elif len(code_verifier) > 128: @@ -59,7 +61,7 @@ def _generate_code_verifier(): def _generate_state_parameter(): state = _base64.urlsafe_b64encode(_os.urandom(_random_seed_length)).decode(_utf_8) # Eliminate invalid characters. 
- code_verifier = _re.sub('[^a-zA-Z0-9-_.,]+', '', state) + code_verifier = _re.sub("[^a-zA-Z0-9-_.,]+", "", state) return code_verifier @@ -72,7 +74,7 @@ def _create_code_challenge(code_verifier): code_challenge = _hashlib.sha256(code_verifier.encode(_utf_8)).digest() code_challenge = _base64.urlsafe_b64encode(code_challenge).decode(_utf_8) # Eliminate invalid characters - code_challenge = code_challenge.replace('=', '') + code_challenge = code_challenge.replace("=", "") return code_challenge @@ -106,7 +108,7 @@ def do_GET(self): self.send_response(_StatusCodes.NOT_FOUND) def handle_login(self, data): - self.server.handle_authorization_code(AuthorizationCode(data['code'], data['state'])) + self.server.handle_authorization_code(AuthorizationCode(data["code"], data["state"])) class OAuthHTTPServer(_BaseHTTPServer.HTTPServer): @@ -114,8 +116,10 @@ class OAuthHTTPServer(_BaseHTTPServer.HTTPServer): A simple wrapper around the BaseHTTPServer.HTTPServer implementation that binds an authorization_client for handling authorization code callbacks. """ - def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True, - redirect_path=None, queue=None): + + def __init__( + self, server_address, RequestHandlerClass, bind_and_activate=True, redirect_path=None, queue=None, + ): _BaseHTTPServer.HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate) self._redirect_path = redirect_path self._auth_code = None @@ -151,7 +155,7 @@ def __init__(self, auth_endpoint=None, token_endpoint=None, client_id=None, redi self._state = state self._credentials = None self._refresh_token = None - self._headers = {'content-type': "application/x-www-form-urlencoded"} + self._headers = {"content-type": "application/x-www-form-urlencoded"} self._expired = False self._params = { @@ -225,22 +229,18 @@ def _initialize_credentials(self, auth_token_resp): def request_access_token(self, auth_code): if self._state != auth_code.state: raise ValueError("Unexpected state parameter [{}] passed".format(auth_code.state)) - self._params.update({ - "code": auth_code.code, - "code_verifier": self._code_verifier, - "grant_type": "authorization_code", - }) + self._params.update( + {"code": auth_code.code, "code_verifier": self._code_verifier, "grant_type": "authorization_code"} + ) resp = _requests.post( - url=self._token_endpoint, - data=self._params, - headers=self._headers, - allow_redirects=False + url=self._token_endpoint, data=self._params, headers=self._headers, allow_redirects=False, ) if resp.status_code != _StatusCodes.OK: # TODO: handle expected (?) 
error cases: # https://auth0.com/docs/flows/guides/device-auth/call-api-device-auth#token-responses - raise Exception('Failed to request access token with response: [{}] {}'.format( - resp.status_code, resp.content)) + raise Exception( + "Failed to request access token with response: [{}] {}".format(resp.status_code, resp.content) + ) self._initialize_credentials(resp) def refresh_access_token(self): @@ -249,11 +249,9 @@ def refresh_access_token(self): resp = _requests.post( url=self._token_endpoint, - data={'grant_type': 'refresh_token', - 'client_id': self._client_id, - 'refresh_token': self._refresh_token}, + data={"grant_type": "refresh_token", "client_id": self._client_id, "refresh_token": self._refresh_token}, headers=self._headers, - allow_redirects=False + allow_redirects=False, ) if resp.status_code != _StatusCodes.OK: self._expired = True diff --git a/flytekit/clis/auth/credentials.py b/flytekit/clis/auth/credentials.py index 3e53e90375..05d41e9f42 100644 --- a/flytekit/clis/auth/credentials.py +++ b/flytekit/clis/auth/credentials.py @@ -1,18 +1,15 @@ from __future__ import absolute_import import logging as _logging +import urllib.parse as _urlparse from flytekit.clis.auth.auth import AuthorizationClient as _AuthorizationClient from flytekit.clis.auth.discovery import DiscoveryClient as _DiscoveryClient - -from flytekit.configuration.creds import ( - REDIRECT_URI as _REDIRECT_URI, - CLIENT_ID as _CLIENT_ID, -) -from flytekit.configuration.platform import URL as _URL, INSECURE as _INSECURE, HTTP_URL as _HTTP_URL - - -import urllib.parse as _urlparse +from flytekit.configuration.creds import CLIENT_ID as _CLIENT_ID +from flytekit.configuration.creds import REDIRECT_URI as _REDIRECT_URI +from flytekit.configuration.platform import HTTP_URL as _HTTP_URL +from flytekit.configuration.platform import INSECURE as _INSECURE +from flytekit.configuration.platform import URL as _URL # Default, well known-URI string used for fetching JSON metadata. See https://tools.ietf.org/html/rfc8414#section-3. discovery_endpoint_path = "./.well-known/oauth-authorization-server" @@ -23,17 +20,17 @@ def _get_discovery_endpoint(http_config_val, platform_url_val, insecure_val): if http_config_val: scheme, netloc, path, _, _, _ = _urlparse.urlparse(http_config_val) if not scheme: - scheme = 'http' if insecure_val else 'https' + scheme = "http" if insecure_val else "https" else: # Use the main _URL config object effectively - scheme = 'http' if insecure_val else 'https' + scheme = "http" if insecure_val else "https" netloc = platform_url_val - path = '' + path = "" computed_endpoint = _urlparse.urlunparse((scheme, netloc, path, None, None, None)) # The urljoin function needs a trailing slash in order to append things correctly. Also, having an extra slash # at the end is okay, it just gets stripped out. 
- computed_endpoint = _urlparse.urljoin(computed_endpoint + '/', discovery_endpoint_path) - _logging.info('Using {} as discovery endpoint'.format(computed_endpoint)) + computed_endpoint = _urlparse.urljoin(computed_endpoint + "/", discovery_endpoint_path) + _logging.info("Using {} as discovery endpoint".format(computed_endpoint)) return computed_endpoint @@ -47,10 +44,12 @@ def get_client(flyte_client_url): return _authorization_client authorization_endpoints = get_authorization_endpoints(flyte_client_url) - _authorization_client =\ - _AuthorizationClient(redirect_uri=_REDIRECT_URI.get(), client_id=_CLIENT_ID.get(), - auth_endpoint=authorization_endpoints.auth_endpoint, - token_endpoint=authorization_endpoints.token_endpoint) + _authorization_client = _AuthorizationClient( + redirect_uri=_REDIRECT_URI.get(), + client_id=_CLIENT_ID.get(), + auth_endpoint=authorization_endpoints.auth_endpoint, + token_endpoint=authorization_endpoints.token_endpoint, + ) return _authorization_client diff --git a/flytekit/clis/auth/discovery.py b/flytekit/clis/auth/discovery.py index d94d677eee..d661f6302e 100644 --- a/flytekit/clis/auth/discovery.py +++ b/flytekit/clis/auth/discovery.py @@ -1,13 +1,6 @@ -import requests as _requests import logging -try: # Python 3.5+ - from http import HTTPStatus as _StatusCodes -except ImportError: - try: # Python 3 - from http import client as _StatusCodes - except ImportError: # Python 2 - import httplib as _StatusCodes +import requests as _requests # These response keys are defined in https://tools.ietf.org/id/draft-ietf-oauth-discovery-08.html. _authorization_endpoint_key = "authorization_endpoint" @@ -18,6 +11,7 @@ class AuthorizationEndpoints(object): """ A simple wrapper around commonly discovered endpoints used for the PKCE auth flow. 
""" + def __init__(self, auth_endpoint=None, token_endpoint=None): self._auth_endpoint = auth_endpoint self._token_endpoint = token_endpoint @@ -52,9 +46,7 @@ def authorization_endpoints(self): def get_authorization_endpoints(self): if self.authorization_endpoints is not None: return self.authorization_endpoints - resp = _requests.get( - url=self._discovery_url, - ) + resp = _requests.get(url=self._discovery_url,) response_body = resp.json() @@ -62,17 +54,18 @@ def get_authorization_endpoints(self): token_endpoint = response_body[_token_endpoint_key] if authorization_endpoint is None: - raise ValueError('Unable to discover authorization endpoint') + raise ValueError("Unable to discover authorization endpoint") if token_endpoint is None: - raise ValueError('Unable to discover token endpoint') + raise ValueError("Unable to discover token endpoint") if authorization_endpoint.startswith("/"): - authorization_endpoint= _requests.compat.urljoin(self._discovery_url, authorization_endpoint) + authorization_endpoint = _requests.compat.urljoin(self._discovery_url, authorization_endpoint) if token_endpoint.startswith("/"): token_endpoint = _requests.compat.urljoin(self._discovery_url, token_endpoint) - self._authorization_endpoints = AuthorizationEndpoints(auth_endpoint=authorization_endpoint, - token_endpoint=token_endpoint) + self._authorization_endpoints = AuthorizationEndpoints( + auth_endpoint=authorization_endpoint, token_endpoint=token_endpoint + ) return self.authorization_endpoints diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 9d4ac1c227..2f8da2bcd9 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -2,43 +2,51 @@ import importlib as _importlib import os as _os -import sys as _sys import stat as _stat +import sys as _sys import click as _click +import requests as _requests import six as _six - -from flyteidl.core import literals_pb2 as _literals_pb2, identifier_pb2 as _identifier_pb2 -from flyteidl.admin import launch_plan_pb2 as _launch_plan_pb2, workflow_pb2 as _workflow_pb2, task_pb2 as _task_pb2 +from flyteidl.admin import launch_plan_pb2 as _launch_plan_pb2 +from flyteidl.admin import task_pb2 as _task_pb2 +from flyteidl.admin import workflow_pb2 as _workflow_pb2 +from flyteidl.core import identifier_pb2 as _identifier_pb2 +from flyteidl.core import literals_pb2 as _literals_pb2 from flytekit import __version__ from flytekit.clients import friendly as _friendly_client -from flytekit.clis.helpers import construct_literal_map_from_variable_map as _construct_literal_map_from_variable_map, \ - construct_literal_map_from_parameter_map as _construct_literal_map_from_parameter_map, \ - parse_args_into_dict as _parse_args_into_dict -from flytekit.common import utils as _utils, launch_plan as _launch_plan_common, \ - workflow_execution as _workflow_execution_common +from flytekit.clis.helpers import construct_literal_map_from_parameter_map as _construct_literal_map_from_parameter_map +from flytekit.clis.helpers import construct_literal_map_from_variable_map as _construct_literal_map_from_variable_map +from flytekit.clis.helpers import parse_args_into_dict as _parse_args_into_dict +from flytekit.common import launch_plan as _launch_plan_common +from flytekit.common import utils as _utils +from flytekit.common import workflow_execution as _workflow_execution_common from flytekit.common.core import identifier as _identifier +from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.tasks import task 
as _tasks_common from flytekit.common.types import helpers as _type_helpers from flytekit.common.utils import load_proto_from_file as _load_proto_from_file from flytekit.configuration import platform as _platform_config from flytekit.configuration import set_flyte_config_file from flytekit.interfaces.data import data_proxy as _data_proxy -from flytekit.models import common as _common_models, filters as _filters, launch_plan as _launch_plan, literals as \ - _literals, named_entity as _named_entity +from flytekit.models import common as _common_models +from flytekit.models import filters as _filters +from flytekit.models import launch_plan as _launch_plan +from flytekit.models import literals as _literals +from flytekit.models import named_entity as _named_entity from flytekit.models.admin import common as _admin_common -from flytekit.models.core import execution as _core_execution_models, identifier as _core_identifier -from flytekit.models.execution import ExecutionSpec as _ExecutionSpec, ExecutionMetadata as _ExecutionMetadata -from flytekit.models.matchable_resource import ClusterResourceAttributes as _ClusterResourceAttributes,\ - ExecutionQueueAttributes as _ExecutionQueueAttributes, ExecutionClusterLabel as _ExecutionClusterLabel,\ - MatchingAttributes as _MatchingAttributes +from flytekit.models.core import execution as _core_execution_models +from flytekit.models.core import identifier as _core_identifier +from flytekit.models.execution import ExecutionMetadata as _ExecutionMetadata +from flytekit.models.execution import ExecutionSpec as _ExecutionSpec +from flytekit.models.matchable_resource import ClusterResourceAttributes as _ClusterResourceAttributes +from flytekit.models.matchable_resource import ExecutionClusterLabel as _ExecutionClusterLabel +from flytekit.models.matchable_resource import ExecutionQueueAttributes as _ExecutionQueueAttributes +from flytekit.models.matchable_resource import MatchingAttributes as _MatchingAttributes from flytekit.models.project import Project as _Project from flytekit.models.schedule import Schedule as _Schedule -from flytekit.common.exceptions import user as _user_exceptions - -import requests as _requests try: # Python 3 import urllib.parse as _urlparse except ImportError: # Python 2 @@ -69,11 +77,14 @@ def _get_config_file_path(): def _detect_default_config_file(): config_file = _get_config_file_path() if _get_user_filepath_home() and _os.path.exists(config_file): - _click.secho("Using default config file at {}".format(_tt(config_file)), fg='blue') + _click.secho("Using default config file at {}".format(_tt(config_file)), fg="blue") set_flyte_config_file(config_file_path=config_file) else: - _click.secho("""Config file not found at default location, relying on environment variables instead. - To setup your config file run 'flyte-cli setup-config'""", fg='blue') + _click.secho( + """Config file not found at default location, relying on environment variables instead. 
+ To setup your config file run 'flyte-cli setup-config'""", + fg="blue", + ) # Run this as the module is loading to pick up settings that click can then use when constructing the commands @@ -90,9 +101,9 @@ def _get_io_string(literal_map, verbose=False): if value_dict: return "\n" + "\n".join( "{:30}: {}".format( - k, - _prefix_lines("{:30} ".format(""), v.verbose_string() if verbose else v.short_string()) - ) for k, v in _six.iteritems(value_dict) + k, _prefix_lines("{:30} ".format(""), v.verbose_string() if verbose else v.short_string(),), + ) + for k, v in _six.iteritems(value_dict) ) else: return "(None)" @@ -112,7 +123,7 @@ def _fetch_and_stringify_literal_map(path, verbose=False): _utils.load_proto_from_file(_literals_pb2.LiteralMap, fname) ) return _get_io_string(literal_map, verbose=verbose) - except: + except Exception: return "Failed to pull data from {}. Do you have permissions?".format(path) @@ -130,30 +141,30 @@ def _secho_workflow_status(status, nl=True): _core_execution_models.WorkflowExecutionPhase.FAILED, _core_execution_models.WorkflowExecutionPhase.ABORTED, _core_execution_models.WorkflowExecutionPhase.FAILING, - _core_execution_models.WorkflowExecutionPhase.TIMED_OUT + _core_execution_models.WorkflowExecutionPhase.TIMED_OUT, } yellow_phases = { _core_execution_models.WorkflowExecutionPhase.QUEUED, - _core_execution_models.WorkflowExecutionPhase.UNDEFINED + _core_execution_models.WorkflowExecutionPhase.UNDEFINED, } green_phases = { _core_execution_models.WorkflowExecutionPhase.SUCCEEDED, - _core_execution_models.WorkflowExecutionPhase.SUCCEEDING + _core_execution_models.WorkflowExecutionPhase.SUCCEEDING, } if status in red_phases: - fg = 'red' + fg = "red" elif status in yellow_phases: - fg = 'yellow' + fg = "yellow" elif status in green_phases: - fg = 'green' + fg = "green" else: - fg = 'blue' + fg = "blue" _click.secho( "{:10} ".format(_tt(_core_execution_models.WorkflowExecutionPhase.enum_to_string(status))), bold=True, fg=fg, - nl=nl + nl=nl, ) @@ -162,29 +173,24 @@ def _secho_node_execution_status(status, nl=True): _core_execution_models.NodeExecutionPhase.FAILING, _core_execution_models.NodeExecutionPhase.FAILED, _core_execution_models.NodeExecutionPhase.ABORTED, - _core_execution_models.NodeExecutionPhase.TIMED_OUT + _core_execution_models.NodeExecutionPhase.TIMED_OUT, } yellow_phases = { _core_execution_models.NodeExecutionPhase.QUEUED, - _core_execution_models.NodeExecutionPhase.UNDEFINED - } - green_phases = { - _core_execution_models.NodeExecutionPhase.SUCCEEDED + _core_execution_models.NodeExecutionPhase.UNDEFINED, } + green_phases = {_core_execution_models.NodeExecutionPhase.SUCCEEDED} if status in red_phases: - fg = 'red' + fg = "red" elif status in yellow_phases: - fg = 'yellow' + fg = "yellow" elif status in green_phases: - fg = 'green' + fg = "green" else: - fg = 'blue' + fg = "blue" _click.secho( - "{:10} ".format(_tt(_core_execution_models.NodeExecutionPhase.enum_to_string(status))), - bold=True, - fg=fg, - nl=nl + "{:10} ".format(_tt(_core_execution_models.NodeExecutionPhase.enum_to_string(status))), bold=True, fg=fg, nl=nl, ) @@ -196,25 +202,20 @@ def _secho_task_execution_status(status, nl=True): yellow_phases = { _core_execution_models.TaskExecutionPhase.QUEUED, _core_execution_models.TaskExecutionPhase.UNDEFINED, - _core_execution_models.TaskExecutionPhase.RUNNING - } - green_phases = { - _core_execution_models.TaskExecutionPhase.SUCCEEDED + _core_execution_models.TaskExecutionPhase.RUNNING, } + green_phases = 
{_core_execution_models.TaskExecutionPhase.SUCCEEDED} if status in red_phases: - fg = 'red' + fg = "red" elif status in yellow_phases: - fg = 'yellow' + fg = "yellow" elif status in green_phases: - fg = 'green' + fg = "green" else: - fg = 'blue' + fg = "blue" _click.secho( - "{:10} ".format(_tt(_core_execution_models.TaskExecutionPhase.enum_to_string(status))), - bold=True, - fg=fg, - nl=nl + "{:10} ".format(_tt(_core_execution_models.TaskExecutionPhase.enum_to_string(status))), bold=True, fg=fg, nl=nl, ) @@ -226,25 +227,19 @@ def _secho_one_execution(ex, urns_only): _tt(ex.id.name), _tt(ex.spec.launch_plan.name), ), - nl=False + nl=False, ) _secho_workflow_status(ex.closure.phase) else: _click.echo( - "{:100}".format( - _tt(_identifier.WorkflowExecutionIdentifier.promote_from_model(ex.id)) - ), - nl=True + "{:100}".format(_tt(_identifier.WorkflowExecutionIdentifier.promote_from_model(ex.id))), nl=True, ) def _terminate_one_execution(client, urn, cause, shouldPrint=True): if shouldPrint: _click.echo("{:100} {:40}".format(_tt(urn), _tt(cause))) - client.terminate_execution( - _identifier.WorkflowExecutionIdentifier.from_python_std(urn), - cause - ) + client.terminate_execution(_identifier.WorkflowExecutionIdentifier.from_python_std(urn), cause) def _update_one_launch_plan(urn, host, insecure, state): @@ -261,15 +256,11 @@ def _update_one_launch_plan(urn, host, insecure, state): def _render_schedule_expr(lp): sched_expr = "NONE" if lp.spec.entity_metadata.schedule.cron_expression: - sched_expr = "cron({cron_expr})".format( - cron_expr=_tt(lp.spec.entity_metadata.schedule.cron_expression) - ) + sched_expr = "cron({cron_expr})".format(cron_expr=_tt(lp.spec.entity_metadata.schedule.cron_expression)) elif lp.spec.entity_metadata.schedule.rate: sched_expr = "rate({unit}={value})".format( - unit=_tt(_Schedule.FixedRateUnit.enum_to_string( - lp.spec.entity_metadata.schedule.rate.unit - )), - value=_tt(lp.spec.entity_metadata.schedule.rate.value) + unit=_tt(_Schedule.FixedRateUnit.enum_to_string(lp.spec.entity_metadata.schedule.rate.unit)), + value=_tt(lp.spec.entity_metadata.schedule.rate.value), ) return "{:30}".format(sched_expr) @@ -298,212 +289,145 @@ def _render_schedule_expr(lp): _PRINCIPAL_FLAGS = ["-r", "--principal"] _INSECURE_FLAGS = ["-i", "--insecure"] -_project_option = _click.option( - *_PROJECT_FLAGS, - required=True, - help="The project namespace to query." -) +_project_option = _click.option(*_PROJECT_FLAGS, required=True, help="The project namespace to query.") _optional_project_option = _click.option( - *_PROJECT_FLAGS, - required=False, - default=None, - help="[Optional] The project namespace to query." -) -_domain_option = _click.option( - *_DOMAIN_FLAGS, - required=True, - help="The domain namespace to query." + *_PROJECT_FLAGS, required=False, default=None, help="[Optional] The project namespace to query.", ) +_domain_option = _click.option(*_DOMAIN_FLAGS, required=True, help="The domain namespace to query.") _optional_domain_option = _click.option( - *_DOMAIN_FLAGS, - required=False, - default=None, - help="[Optional] The domain namespace to query." -) -_name_option = _click.option( - *_NAME_FLAGS, - required=True, - help="The name to query." + *_DOMAIN_FLAGS, required=False, default=None, help="[Optional] The domain namespace to query.", ) +_name_option = _click.option(*_NAME_FLAGS, required=True, help="The name to query.") _optional_name_option = _click.option( - *_NAME_FLAGS, - required=False, - type=str, - default=None, - help="[Optional] The name to query." 
-) -_principal_option = _click.option( - *_PRINCIPAL_FLAGS, - required=True, - help="Your team name, or your name" + *_NAME_FLAGS, required=False, type=str, default=None, help="[Optional] The name to query.", ) +_principal_option = _click.option(*_PRINCIPAL_FLAGS, required=True, help="Your team name, or your name") _optional_principal_option = _click.option( - *_PRINCIPAL_FLAGS, - required=False, - type=str, - default=None, - help="[Optional] Your team name, or your name" -) -_insecure_option = _click.option( - *_INSECURE_FLAGS, - is_flag=True, - required=True, - help="Do not use SSL" -) -_urn_option = _click.option( - "-u", "--urn", - required=True, - help="The unique identifier for an entity." + *_PRINCIPAL_FLAGS, required=False, type=str, default=None, help="[Optional] Your team name, or your name", ) +_insecure_option = _click.option(*_INSECURE_FLAGS, is_flag=True, required=True, help="Do not use SSL") +_urn_option = _click.option("-u", "--urn", required=True, help="The unique identifier for an entity.") -_optional_urn_option = _click.option( - "-u", "--urn", - required=False, - help="The unique identifier for an entity." -) +_optional_urn_option = _click.option("-u", "--urn", required=False, help="The unique identifier for an entity.") _host_option = _click.option( *_HOST_FLAGS, required=not bool(_HOST_URL), default=_HOST_URL, help="The URL for the Flyte Admin Service. If you intend for this to be consistent, set the FLYTE_PLATFORM_URL " - "environment variable to the desired URL and this will not need to be set." + "environment variable to the desired URL and this will not need to be set.", ) _token_option = _click.option( - "-t", "--token", + "-t", + "--token", required=False, default="", type=str, - help="Pagination token from which to start listing in the list of results." + help="Pagination token from which to start listing in the list of results.", ) _limit_option = _click.option( - "-l", "--limit", - required=False, - default=100, - type=int, - help="Maximum number of results to return for this call." + "-l", "--limit", required=False, default=100, type=int, help="Maximum number of results to return for this call.", ) _show_all_option = _click.option( - "-a", "--show-all", - is_flag=True, - default=False, - help="Set this flag to page through and list all results." + "-a", "--show-all", is_flag=True, default=False, help="Set this flag to page through and list all results.", ) # TODO: Provide documentation on filter format _filter_option = _click.option( - "-f", "--filter", + "-f", + "--filter", multiple=True, - help="Filter to be applied. Multiple filters can be applied and they will be ANDed together." + help="Filter to be applied. Multiple filters can be applied and they will be ANDed together.", ) _state_choice = _click.option( "--state", type=_click.Choice(["active", "inactive"]), required=True, - help="Whether or not to set schedule as active." + help="Whether or not to set schedule as active.", ) _named_entity_state_choice = _click.option( "--state", type=_click.Choice(["active", "archived"]), required=True, - help="The state change to apply to a named entity" + help="The state change to apply to a named entity", ) _named_entity_description_option = _click.option( - "--description", - required=False, - type=str, - help="Concise description for the entity." + "--description", required=False, type=str, help="Concise description for the entity.", ) _sort_by_option = _click.option( "--sort-by", required=False, - help="Provide an entity type and field to be sorted. i.e. 
asc(workflow.name) or desc(workflow.name)" + help="Provide an entity type and field to be sorted. e.g. asc(workflow.name) or desc(workflow.name)", ) _show_io_option = _click.option( "--show-io", is_flag=True, default=False, help="Set this flag to view inputs and outputs. Pair with the --verbose flag to get the full textual description" - " inputs and outputs." + " of inputs and outputs.", ) _verbose_option = _click.option( - "--verbose", - is_flag=True, - default=False, - help="Set this flag to view the full textual description of all fields." + "--verbose", is_flag=True, default=False, help="Set this flag to view the full textual description of all fields.", ) -_filename_option = _click.option( - '-f', '--filename', - required=True, - help="File path of pb file" -) +_filename_option = _click.option("-f", "--filename", required=True, help="File path of pb file") _idl_class_option = _click.option( - '-p', '--proto_class', + "-p", + "--proto_class", required=True, - help="Dot (.) separated path to Python IDL class. (e.g. flyteidl.core.workflow_closure_pb2.WorkflowClosure)" + help="Dot (.) separated path to Python IDL class. (e.g. flyteidl.core.workflow_closure_pb2.WorkflowClosure)", ) _cause_option = _click.option( - '-c', '--cause', - required=True, - help="The message signaling the cause of the termination of the execution(s)" + "-c", "--cause", required=True, help="The message signaling the cause of the termination of the execution(s)", ) _optional_urns_only_option = _click.option( - '--urns-only', + "--urns-only", is_flag=True, default=False, required=False, - help="[Optional] Set the flag if you want to output the urn(s) only. Setting this will override the verbose flag" + help="[Optional] Set the flag if you want to output the urn(s) only. Setting this will override the verbose flag", ) _project_identifier_option = _click.option( - '-p', '--identifier', - required=True, - type=str, - help="Unique identifier for the project." + "-p", "--identifier", required=True, type=str, help="Unique identifier for the project.", ) _project_name_option = _click.option( - '-n', '--name', - required=True, - type=str, - help="The human-readable name for the project." + "-n", "--name", required=True, type=str, help="The human-readable name for the project.", ) _project_description_option = _click.option( - '-d', '--description', - required=True, - type=str, - help="Concise description for the project."
+ "-d", "--description", required=True, type=str, help="Concise description for the project.", ) _watch_option = _click.option( - '-w', '--watch', + "-w", + "--watch", is_flag=True, default=False, - help="Set the flag if you want the command to keep watching the execution until its completion" + help="Set the flag if you want the command to keep watching the execution until its completion", ) class _FlyteSubCommand(_click.Command): _PASSABLE_ARGS = { - 'project': _PROJECT_FLAGS[0], - 'domain': _DOMAIN_FLAGS[0], - 'name': _NAME_FLAGS[0], - 'host': _HOST_FLAGS[0], + "project": _PROJECT_FLAGS[0], + "domain": _DOMAIN_FLAGS[0], + "name": _NAME_FLAGS[0], + "host": _HOST_FLAGS[0], } _PASSABLE_FLAGS = { - 'insecure': _INSECURE_FLAGS[0], + "insecure": _INSECURE_FLAGS[0], } def make_context(self, cmd_name, args, parent=None): prefix_args = [] for param in self.params: - if param.name in type(self)._PASSABLE_ARGS and \ - param.name in parent.params and \ - parent.params[param.name] is not None: + if ( + param.name in type(self)._PASSABLE_ARGS + and param.name in parent.params + and parent.params[param.name] is not None + ): prefix_args.extend([type(self)._PASSABLE_ARGS[param.name], _six.text_type(parent.params[param.name])]) # For flags, we don't append the value of the flag, otherwise click will fail to parse - if param.name in type(self)._PASSABLE_FLAGS and \ - param.name in parent.params and \ - parent.params[param.name]: + if param.name in type(self)._PASSABLE_FLAGS and param.name in parent.params and parent.params[param.name]: prefix_args.append(type(self)._PASSABLE_FLAGS[param.name]) # This is where we handle the value read from the flyte-cli config file, if any, for the insecure flag. @@ -523,7 +447,7 @@ def make_context(self, cmd_name, args, parent=None): type=str, default=None, help="[Optional] The host to pass to the sub-command (if applicable). If set again in the sub-command, " - "the sub-command's parameter takes precedence." + "the sub-command's parameter takes precedence.", ) @_click.option( *_PROJECT_FLAGS, @@ -531,7 +455,7 @@ def make_context(self, cmd_name, args, parent=None): type=str, default=None, help="[Optional] The project to pass to the sub-command (if applicable) If set again in the sub-command, " - "the sub-command's parameter takes precedence." + "the sub-command's parameter takes precedence.", ) @_click.option( *_DOMAIN_FLAGS, @@ -539,7 +463,7 @@ def make_context(self, cmd_name, args, parent=None): type=str, default=None, help="[Optional] The domain to pass to the sub-command (if applicable) If set again in the sub-command, " - "the sub-command's parameter takes precedence." + "the sub-command's parameter takes precedence.", ) @_click.option( *_NAME_FLAGS, @@ -547,7 +471,7 @@ def make_context(self, cmd_name, args, parent=None): type=str, default=None, help="[Optional] The name to pass to the sub-command (if applicable) If set again in the sub-command, " - "the sub-command's parameter takes precedence." 
+ "the sub-command's parameter takes precedence.", ) @_insecure_option @_click.group("flyte-cli") @@ -566,13 +490,13 @@ def _flyte_cli(ctx, host, project, domain, name, insecure): ######################################################################################################################## -@_flyte_cli.command('parse-proto', cls=_click.Command) +@_flyte_cli.command("parse-proto", cls=_click.Command) @_filename_option @_idl_class_option def parse_proto(filename, proto_class): _welcome_message() - splitted = proto_class.split('.') - idl_module = '.'.join(splitted[:-1]) + splitted = proto_class.split(".") + idl_module = ".".join(splitted[:-1]) idl_obj = splitted[-1] mod = _importlib.import_module(idl_module) idl = getattr(mod, idl_obj) @@ -588,7 +512,8 @@ def parse_proto(filename, proto_class): # ######################################################################################################################## -@_flyte_cli.command('list-task-names', cls=_FlyteSubCommand) + +@_flyte_cli.command("list-task-names", cls=_FlyteSubCommand) @_project_option @_domain_option @_host_option @@ -612,7 +537,7 @@ def list_task_names(project, domain, host, insecure, token, limit, show_all, sor domain, limit=limit, token=token, - sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None + sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None, ) for t in task_ids: _click.echo("\t{}".format(_tt(t.name))) @@ -627,7 +552,7 @@ def list_task_names(project, domain, host, insecure, token, limit, show_all, sor _click.echo("") -@_flyte_cli.command('list-task-versions', cls=_FlyteSubCommand) +@_flyte_cli.command("list-task-versions", cls=_FlyteSubCommand) @_project_option @_domain_option @_optional_name_option @@ -647,22 +572,18 @@ def list_task_versions(project, domain, name, host, insecure, token, limit, show _welcome_message() client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) - _click.echo("Task Versions Found for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name or '*'))) - _click.echo("{:50} {:40}".format('Version', 'Urn')) + _click.echo("Task Versions Found for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name or "*"))) + _click.echo("{:50} {:40}".format("Version", "Urn")) while True: task_list, next_token = client.list_tasks_paginated( - _common_models.NamedEntityIdentifier( - project, - domain, - name - ), + _common_models.NamedEntityIdentifier(project, domain, name), limit=limit, token=token, filters=[_filters.Filter.from_python_std(f) for f in filter], - sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None + sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None, ) for t in task_list: - _click.echo("{:50} {:40}".format(_tt(t.id.version), _tt(_identifier.Identifier.promote_from_model(t.id)))) + _click.echo("{:50} {:40}".format(_tt(t.id.version), _tt(_identifier.Identifier.promote_from_model(t.id)),)) if show_all is not True: if next_token: @@ -674,7 +595,7 @@ def list_task_versions(project, domain, name, host, insecure, token, limit, show _click.echo("") -@_flyte_cli.command('get-task', cls=_FlyteSubCommand) +@_flyte_cli.command("get-task", cls=_FlyteSubCommand) @_urn_option @_host_option @_insecure_option @@ -690,14 +611,14 @@ def get_task(urn, host, insecure): _click.echo("") -@_flyte_cli.command('launch-task', cls=_FlyteSubCommand) +@_flyte_cli.command("launch-task", cls=_FlyteSubCommand) @_project_option @_domain_option @_optional_name_option @_host_option @_insecure_option @_urn_option 
-@_click.argument('task_args', nargs=-1, type=_click.UNPROCESSED) +@_click.argument("task_args", nargs=-1, type=_click.UNPROCESSED) def launch_task(project, domain, name, host, insecure, urn, task_args): """ Kick off a single task execution. Note that the {project, domain, name} specified in the command line @@ -729,7 +650,7 @@ def launch_task(project, domain, name, host, insecure, urn, task_args): # TODO: Implement label overrides # TODO: Implement annotation overrides execution = task.launch(project, domain, inputs=inputs, name=name) - _click.secho("Launched execution: {}".format(_tt(execution.id)), fg='blue') + _click.secho("Launched execution: {}".format(_tt(execution.id)), fg="blue") _click.echo("") @@ -739,7 +660,8 @@ def launch_task(project, domain, name, host, insecure, urn, task_args): # ######################################################################################################################## -@_flyte_cli.command('list-workflow-names', cls=_FlyteSubCommand) + +@_flyte_cli.command("list-workflow-names", cls=_FlyteSubCommand) @_project_option @_domain_option @_host_option @@ -762,7 +684,7 @@ def list_workflow_names(project, domain, host, insecure, token, limit, show_all, domain, limit=limit, token=token, - sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None + sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None, ) for i in wf_ids: _click.echo("\t{}".format(_tt(i.name))) @@ -777,7 +699,7 @@ def list_workflow_names(project, domain, host, insecure, token, limit, show_all, _click.echo("") -@_flyte_cli.command('list-workflow-versions', cls=_FlyteSubCommand) +@_flyte_cli.command("list-workflow-versions", cls=_FlyteSubCommand) @_project_option @_domain_option @_optional_name_option @@ -797,22 +719,18 @@ def list_workflow_versions(project, domain, name, host, insecure, token, limit, _welcome_message() client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) - _click.echo("Workflow Versions Found for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name or '*'))) - _click.echo("{:50} {:40}".format('Version', 'Urn')) + _click.echo("Workflow Versions Found for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name or "*"))) + _click.echo("{:50} {:40}".format("Version", "Urn")) while True: wf_list, next_token = client.list_workflows_paginated( - _common_models.NamedEntityIdentifier( - project, - domain, - name - ), + _common_models.NamedEntityIdentifier(project, domain, name), limit=limit, token=token, filters=[_filters.Filter.from_python_std(f) for f in filter], - sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None + sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None, ) for w in wf_list: - _click.echo("{:50} {:40}".format(_tt(w.id.version), _tt(_identifier.Identifier.promote_from_model(w.id)))) + _click.echo("{:50} {:40}".format(_tt(w.id.version), _tt(_identifier.Identifier.promote_from_model(w.id)),)) if show_all is not True: if next_token: @@ -824,7 +742,7 @@ def list_workflow_versions(project, domain, name, host, insecure, token, limit, _click.echo("") -@_flyte_cli.command('get-workflow', cls=_FlyteSubCommand) +@_flyte_cli.command("get-workflow", cls=_FlyteSubCommand) @_urn_option @_host_option @_insecure_option @@ -846,7 +764,8 @@ def get_workflow(urn, host, insecure): # ######################################################################################################################## -@_flyte_cli.command('list-launch-plan-names', cls=_FlyteSubCommand) + 
+@_flyte_cli.command("list-launch-plan-names", cls=_FlyteSubCommand) @_project_option @_domain_option @_host_option @@ -869,7 +788,7 @@ def list_launch_plan_names(project, domain, host, insecure, token, limit, show_a domain, limit=limit, token=token, - sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None + sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None, ) for i in wf_ids: _click.echo("\t{}".format(_tt(i.name))) @@ -884,7 +803,7 @@ def list_launch_plan_names(project, domain, host, insecure, token, limit, show_a _click.echo("") -@_flyte_cli.command('list-active-launch-plans', cls=_FlyteSubCommand) +@_flyte_cli.command("list-active-launch-plans", cls=_FlyteSubCommand) @_project_option @_domain_option @_host_option @@ -902,7 +821,7 @@ def list_active_launch_plans(project, domain, host, insecure, token, limit, show if not urns_only: _welcome_message() _click.echo("Active Launch Plan Found in {}:{}\n".format(_tt(project), _tt(domain))) - _click.echo("{:30} {:50} {:80}".format('Schedule', 'Version', 'Urn')) + _click.echo("{:30} {:50} {:80}".format("Schedule", "Version", "Urn")) client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) @@ -912,14 +831,12 @@ def list_active_launch_plans(project, domain, host, insecure, token, limit, show domain, limit=limit, token=token, - sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None + sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None, ) for lp in active_lps: if urns_only: - _click.echo("{:80}".format( - _tt(_identifier.Identifier.promote_from_model(lp.id)) - )) + _click.echo("{:80}".format(_tt(_identifier.Identifier.promote_from_model(lp.id)))) else: _click.echo( "{:30} {:50} {:80}".format( @@ -942,7 +859,7 @@ def list_active_launch_plans(project, domain, host, insecure, token, limit, show return -@_flyte_cli.command('list-launch-plan-versions', cls=_FlyteSubCommand) +@_flyte_cli.command("list-launch-plan-versions", cls=_FlyteSubCommand) @_project_option @_domain_option @_optional_name_option @@ -954,49 +871,40 @@ def list_active_launch_plans(project, domain, host, insecure, token, limit, show @_filter_option @_sort_by_option @_optional_urns_only_option -def list_launch_plan_versions(project, domain, name, host, insecure, token, limit, show_all, filter, sort_by, - urns_only): +def list_launch_plan_versions( + project, domain, name, host, insecure, token, limit, show_all, filter, sort_by, urns_only, +): """ List the versions of all the launch plans under the scope specified by {project, domain}. 
""" if not urns_only: _welcome_message() _click.echo("Launch Plan Versions Found for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name))) - _click.echo("{:50} {:80} {:30} {:15}".format('Version', 'Urn', "Schedule", "Schedule State")) + _click.echo("{:50} {:80} {:30} {:15}".format("Version", "Urn", "Schedule", "Schedule State")) client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) while True: lp_list, next_token = client.list_launch_plans_paginated( - _common_models.NamedEntityIdentifier( - project, - domain, - name - ), + _common_models.NamedEntityIdentifier(project, domain, name), limit=limit, token=token, filters=[_filters.Filter.from_python_std(f) for f in filter], - sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None + sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None, ) for l in lp_list: if urns_only: _click.echo(_tt(_identifier.Identifier.promote_from_model(l.id))) else: _click.echo( - "{:50} {:80} ".format( - _tt(l.id.version), - _tt(_identifier.Identifier.promote_from_model(l.id)) - ), - nl=False + "{:50} {:80} ".format(_tt(l.id.version), _tt(_identifier.Identifier.promote_from_model(l.id)),), + nl=False, ) if l.spec.entity_metadata.schedule.cron_expression or l.spec.entity_metadata.schedule.rate: - _click.echo( - "{:30} ".format(_render_schedule_expr(l)), - nl=False - ) + _click.echo("{:30} ".format(_render_schedule_expr(l)), nl=False) _click.secho( _launch_plan.LaunchPlanState.enum_to_string(l.closure.state), - fg="green" if l.closure.state == _launch_plan.LaunchPlanState.ACTIVE else None + fg="green" if l.closure.state == _launch_plan.LaunchPlanState.ACTIVE else None, ) else: _click.echo() @@ -1012,7 +920,7 @@ def list_launch_plan_versions(project, domain, name, host, insecure, token, limi _click.echo("") -@_flyte_cli.command('get-launch-plan', cls=_FlyteSubCommand) +@_flyte_cli.command("get-launch-plan", cls=_FlyteSubCommand) @_urn_option @_host_option @_insecure_option @@ -1028,7 +936,7 @@ def get_launch_plan(urn, host, insecure): _click.echo("") -@_flyte_cli.command('get-active-launch-plan', cls=_FlyteSubCommand) +@_flyte_cli.command("get-active-launch-plan", cls=_FlyteSubCommand) @_project_option @_domain_option @_name_option @@ -1041,19 +949,13 @@ def get_active_launch_plan(project, domain, name, host, insecure): _welcome_message() client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) - lp = client.get_active_launch_plan( - _common_models.NamedEntityIdentifier( - project, - domain, - name - ) - ) + lp = client.get_active_launch_plan(_common_models.NamedEntityIdentifier(project, domain, name)) _click.echo("Active Launch Plan for {}:{}:{}\n".format(_tt(project), _tt(domain), _tt(name))) _click.echo(lp) _click.echo("") -@_flyte_cli.command('update-launch-plan', cls=_FlyteSubCommand) +@_flyte_cli.command("update-launch-plan", cls=_FlyteSubCommand) @_state_choice @_host_option @_insecure_option @@ -1077,7 +979,8 @@ def update_launch_plan(state, host, insecure, urn=None): else: _update_one_launch_plan(urn=urn, host=host, insecure=insecure, state=state) -@_flyte_cli.command('execute-launch-plan', cls=_FlyteSubCommand) + +@_flyte_cli.command("execute-launch-plan", cls=_FlyteSubCommand) @_project_option @_domain_option @_optional_name_option @@ -1087,7 +990,7 @@ def update_launch_plan(state, host, insecure, urn=None): @_principal_option @_verbose_option @_watch_option -@_click.argument('lp_args', nargs=-1, type=_click.UNPROCESSED) +@_click.argument("lp_args", nargs=-1, type=_click.UNPROCESSED) 
def execute_launch_plan(project, domain, name, host, insecure, urn, principal, verbose, watch, lp_args): """ Kick off a launch plan. Note that the {project, domain, name} specified in the command line @@ -1114,7 +1017,7 @@ def execute_launch_plan(project, domain, name, host, insecure, urn, principal, v # TODO: Implement label overrides # TODO: Implement annotation overrides execution = lp.launch_with_literals(project, domain, inputs, name=name) - _click.secho("Launched execution: {}".format(_tt(execution.id)), fg='blue') + _click.secho("Launched execution: {}".format(_tt(execution.id)), fg="blue") _click.echo("") if watch is True: @@ -1128,7 +1031,7 @@ def execute_launch_plan(project, domain, name, host, insecure, urn, principal, v ######################################################################################################################## -@_flyte_cli.command('watch-execution', cls=_FlyteSubCommand) +@_flyte_cli.command("watch-execution", cls=_FlyteSubCommand) @_host_option @_insecure_option @_urn_option @@ -1150,7 +1053,7 @@ def watch_execution(host, insecure, urn): execution.wait_for_completion() -@_flyte_cli.command('relaunch-execution', cls=_FlyteSubCommand) +@_flyte_cli.command("relaunch-execution", cls=_FlyteSubCommand) @_optional_project_option @_optional_domain_option @_optional_name_option @@ -1159,7 +1062,7 @@ def watch_execution(host, insecure, urn): @_urn_option @_optional_principal_option @_verbose_option -@_click.argument('lp_args', nargs=-1, type=_click.UNPROCESSED) +@_click.argument("lp_args", nargs=-1, type=_click.UNPROCESSED) def relaunch_execution(project, domain, name, host, insecure, urn, principal, verbose, lp_args): """ Relaunch a launch plan. @@ -1206,11 +1109,7 @@ def relaunch_execution(project, domain, name, host, insecure, urn, principal, ve parsed_text_args = _parse_args_into_dict(lp_args) new_inputs = _construct_literal_map_from_variable_map(variable_map, parsed_text_args) if len(new_inputs.literals) > 0: - _click.secho( - "\tNew Inputs: {}\n".format( - _prefix_lines("\t\t", _get_io_string(new_inputs, verbose=verbose)) - ) - ) + _click.secho("\tNew Inputs: {}\n".format(_prefix_lines("\t\t", _get_io_string(new_inputs, verbose=verbose)))) # Construct new inputs from existing execution inputs and new inputs inputs_dict = {} @@ -1232,11 +1131,11 @@ def relaunch_execution(project, domain, name, host, insecure, urn, principal, ve ex_spec = _ExecutionSpec(launch_plan=lp_model.id, inputs=inputs, metadata=metadata) execution_identifier = client.create_execution(project=project, domain=domain, name=name, execution_spec=ex_spec) execution_identifier = _identifier.WorkflowExecutionIdentifier.promote_from_model(execution_identifier) - _click.secho("Launched execution: {}".format(execution_identifier), fg='blue') + _click.secho("Launched execution: {}".format(execution_identifier), fg="blue") _click.echo("") -@_flyte_cli.command('terminate-execution', cls=_FlyteSubCommand) +@_flyte_cli.command("terminate-execution", cls=_FlyteSubCommand) @_host_option @_insecure_option @_cause_option @@ -1286,7 +1185,7 @@ def terminate_execution(host, insecure, cause, urn=None): _terminate_one_execution(client, urn, cause) -@_flyte_cli.command('list-executions', cls=_FlyteSubCommand) +@_flyte_cli.command("list-executions", cls=_FlyteSubCommand) @_project_option @_domain_option @_host_option @@ -1322,7 +1221,7 @@ def list_executions(project, domain, host, insecure, token, limit, show_all, fil limit=limit, token=token, filters=[_filters.Filter.from_python_std(f) for f in filter], - 
sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None + sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None, ) for ex in exec_ids: _secho_one_execution(ex, urns_only) @@ -1344,9 +1243,11 @@ def _get_io(node_executions, wf_execution, show_io, verbose): if show_io: uris = [ne.input_uri for ne in node_executions] uris.extend([ne.closure.output_uri for ne in node_executions if ne.closure.output_uri is not None]) - if wf_execution is not None and \ - wf_execution.closure.outputs is not None and \ - wf_execution.closure.outputs.uri is not None: + if ( + wf_execution is not None + and wf_execution.closure.outputs is not None + and wf_execution.closure.outputs.uri is not None + ): uris.append(wf_execution.closure.outputs.uri) with _click.progressbar(uris, label="Downloading Inputs and Outputs") as progress_bar_uris: @@ -1358,24 +1259,21 @@ def _get_io(node_executions, wf_execution, show_io, verbose): def _render_workflow_execution(wf_execution, uri_to_message_map, show_io, verbose): _click.echo( "\nExecution {project}:{domain}:{name}\n".format( - project=_tt(wf_execution.id.project), - domain=_tt(wf_execution.id.domain), - name=_tt(wf_execution.id.name) + project=_tt(wf_execution.id.project), domain=_tt(wf_execution.id.domain), name=_tt(wf_execution.id.name), ) ) _click.echo("\t{:15} ".format("State:"), nl=False) _secho_workflow_status(wf_execution.closure.phase) _click.echo( "\t{:15} {}".format( - "Launch Plan:", - _tt(_identifier.Identifier.promote_from_model(wf_execution.spec.launch_plan)) + "Launch Plan:", _tt(_identifier.Identifier.promote_from_model(wf_execution.spec.launch_plan)), ) ) if show_io: _click.secho( "\tInputs: {}\n".format( - _prefix_lines("\t\t", _get_io_string(wf_execution.closure.computed_inputs, verbose=verbose)) + _prefix_lines("\t\t", _get_io_string(wf_execution.closure.computed_inputs, verbose=verbose),) ) ) if wf_execution.closure.outputs is not None: @@ -1384,24 +1282,23 @@ def _render_workflow_execution(wf_execution, uri_to_message_map, show_io, verbos "\tOutputs: {}\n".format( _prefix_lines( "\t\t", - uri_to_message_map.get( - wf_execution.closure.outputs.uri, - wf_execution.closure.outputs.uri - ) + uri_to_message_map.get(wf_execution.closure.outputs.uri, wf_execution.closure.outputs.uri,), ) ) ) elif wf_execution.closure.outputs.values is not None: _click.secho( "\tOutputs: {}\n".format( - _prefix_lines("\t\t", _get_io_string(wf_execution.closure.outputs.values, verbose=verbose)) + _prefix_lines("\t\t", _get_io_string(wf_execution.closure.outputs.values, verbose=verbose),) ) ) else: _click.echo("\t{:15} (None)".format("Outputs:")) if wf_execution.closure.error is not None: - _click.secho(_prefix_lines("\t", _render_error(wf_execution.closure.error)), fg='red', bold=True) + _click.secho( + _prefix_lines("\t", _render_error(wf_execution.closure.error)), fg="red", bold=True, + ) def _render_error(error): @@ -1419,9 +1316,7 @@ def _get_all_task_executions_for_node(client, node_execution_identifier): while True: num_to_fetch = 100 task_execs, next_token = client.list_task_executions_paginated( - node_execution_identifier=node_execution_identifier, - limit=num_to_fetch, - token=token + node_execution_identifier=node_execution_identifier, limit=num_to_fetch, token=token, ) for te in task_execs: fetched_task_execs.append(te) @@ -1440,15 +1335,11 @@ def _get_all_node_executions(client, workflow_execution_identifier=None, task_ex num_to_fetch = 100 if workflow_execution_identifier: node_execs, next_token = client.list_node_executions( - 
workflow_execution_identifier=workflow_execution_identifier, - limit=num_to_fetch, - token=token + workflow_execution_identifier=workflow_execution_identifier, limit=num_to_fetch, token=token, ) else: node_execs, next_token = client.list_node_executions_for_task_paginated( - task_execution_identifier=task_execution_identifier, - limit=num_to_fetch, - token=token, + task_execution_identifier=task_execution_identifier, limit=num_to_fetch, token=token, ) all_node_execs.extend(node_execs) if not next_token: @@ -1468,7 +1359,7 @@ def _render_node_executions(client, node_execs, show_io, verbose, host, insecure _click.echo("\n\tNode Executions:\n") for ne in sorted(node_execs, key=lambda x: x.closure.started_at): - if ne.id.node_id == 'start-node': + if ne.id.node_id == "start-node": continue _click.echo("\t\tID: {}\n".format(_tt(ne.id.node_id))) _click.echo("\t\t\t{:15} ".format("Status:"), nl=False) @@ -1477,8 +1368,7 @@ def _render_node_executions(client, node_execs, show_io, verbose, host, insecure _click.echo("\t\t\t{:15} {:60} ".format("Duration:", _tt(ne.closure.duration))) _click.echo( "\t\t\t{:15} {}".format( - "Input:", - _prefix_lines("\t\t\t{:15} ".format(""), uri_to_message_map.get(ne.input_uri, ne.input_uri)) + "Input:", _prefix_lines("\t\t\t{:15} ".format(""), uri_to_message_map.get(ne.input_uri, ne.input_uri),), ) ) if ne.closure.output_uri: @@ -1486,19 +1376,13 @@ def _render_node_executions(client, node_execs, show_io, verbose, host, insecure "\t\t\t{:15} {}".format( "Output:", _prefix_lines( - "\t\t\t{:15} ".format(""), - uri_to_message_map.get(ne.closure.output_uri, ne.closure.output_uri) - ) + "\t\t\t{:15} ".format(""), uri_to_message_map.get(ne.closure.output_uri, ne.closure.output_uri), + ), ) ) if ne.closure.error is not None: _click.secho( - _prefix_lines( - "\t\t\t", - _render_error(ne.closure.error) - ), - bold=True, - fg='red' + _prefix_lines("\t\t\t", _render_error(ne.closure.error)), bold=True, fg="red", ) task_executions = node_executions_to_task_executions.get(ne.id, []) @@ -1522,12 +1406,7 @@ def _render_node_executions(client, node_execs, show_io, verbose, host, insecure if te.closure.error is not None: _click.secho( - _prefix_lines( - "\t\t\t\t\t", - _render_error(te.closure.error) - ), - bold=True, - fg='red' + _prefix_lines("\t\t\t\t\t", _render_error(te.closure.error)), bold=True, fg="red", ) if te.is_parent: @@ -1537,8 +1416,8 @@ def _render_node_executions(client, node_execs, show_io, verbose, host, insecure "flyte-cli get-child-executions -h {host}{insecure} -u {urn}".format( host=host, urn=_tt(_identifier.TaskExecutionIdentifier.promote_from_model(te.id)), - insecure=" --insecure" if insecure else "" - ) + insecure=" --insecure" if insecure else "", + ), ) ) _click.echo() @@ -1546,7 +1425,7 @@ def _render_node_executions(client, node_execs, show_io, verbose, host, insecure _click.echo() -@_flyte_cli.command('get-execution', cls=_FlyteSubCommand) +@_flyte_cli.command("get-execution", cls=_FlyteSubCommand) @_urn_option @_host_option @_insecure_option @@ -1564,7 +1443,7 @@ def get_execution(urn, host, insecure, show_io, verbose): _render_node_executions(client, node_execs, show_io, verbose, host, insecure, wf_execution=e) -@_flyte_cli.command('get-child-executions', cls=_FlyteSubCommand) +@_flyte_cli.command("get-child-executions", cls=_FlyteSubCommand) @_urn_option @_host_option @_insecure_option @@ -1574,13 +1453,12 @@ def get_child_executions(urn, host, insecure, show_io, verbose): _welcome_message() client = _friendly_client.SynchronousFlyteClient(host, 
insecure=insecure) node_execs = _get_all_node_executions( - client, - task_execution_identifier=_identifier.TaskExecutionIdentifier.from_python_std(urn) + client, task_execution_identifier=_identifier.TaskExecutionIdentifier.from_python_std(urn), ) _render_node_executions(client, node_execs, show_io, verbose, host, insecure) -@_flyte_cli.command('register-project', cls=_FlyteSubCommand) +@_flyte_cli.command("register-project", cls=_FlyteSubCommand) @_project_identifier_option @_project_name_option @_project_description_option @@ -1606,12 +1484,13 @@ def _extract_pair(identifier_file, object_file): resource_map = { _identifier_pb2.LAUNCH_PLAN: _launch_plan_pb2.LaunchPlanSpec, _identifier_pb2.WORKFLOW: _workflow_pb2.WorkflowSpec, - _identifier_pb2.TASK: _task_pb2.TaskSpec + _identifier_pb2.TASK: _task_pb2.TaskSpec, } id = _load_proto_from_file(_identifier_pb2.Identifier, identifier_file) - if not id.resource_type in resource_map: - raise _user_exceptions.FlyteAssertion(f"Resource type found in identifier {id.resource_type} invalid, must be launch plan, " - f"task, or workflow") + if id.resource_type not in resource_map: + raise _user_exceptions.FlyteAssertion( + f"Resource type found in identifier {id.resource_type} invalid, must be launch plan, " f"task, or workflow" + ) entity = _load_proto_from_file(resource_map[id.resource_type], object_file) return id, entity @@ -1635,13 +1514,11 @@ def _extract_files(file_paths): return results -@_flyte_cli.command('register-files', cls=_FlyteSubCommand) +@_flyte_cli.command("register-files", cls=_FlyteSubCommand) @_host_option @_insecure_option @_click.argument( - 'files', - type=_click.Path(exists=True), - nargs=-1, + "files", type=_click.Path(exists=True), nargs=-1, ) def register_files(host, insecure, files): """ @@ -1666,7 +1543,7 @@ def register_files(host, insecure, files): client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) files = list(files) files.sort() - _click.secho("Parsing files...", fg='green', bold=True) + _click.secho("Parsing files...", fg="green", bold=True) for f in files: _click.echo(f" {f}") @@ -1680,16 +1557,18 @@ def register_files(host, insecure, files): elif id.resource_type == _identifier_pb2.WORKFLOW: client.raw.create_workflow(_workflow_pb2.WorkflowCreateRequest(id=id, spec=flyte_entity)) else: - raise _user_exceptions.FlyteAssertion(f"Only tasks, launch plans, and workflows can be called with this function, " - f"resource type {id.resource_type} was passed") - _click.secho(f"Registered {id}", fg='green') + raise _user_exceptions.FlyteAssertion( + f"Only tasks, launch plans, and workflows can be called with this function, " + f"resource type {id.resource_type} was passed" + ) + _click.secho(f"Registered {id}", fg="green") except _user_exceptions.FlyteEntityAlreadyExistsException: - _click.secho(f"Skipping because already registered {id}", fg='cyan') + _click.secho(f"Skipping because already registered {id}", fg="cyan") _click.echo(f"Finished scanning {len(flyte_entities_list)} files") -@_flyte_cli.command('update-workflow-meta', cls=_FlyteSubCommand) +@_flyte_cli.command("update-workflow-meta", cls=_FlyteSubCommand) @_named_entity_description_option @_named_entity_state_choice @_host_option @@ -1710,11 +1589,12 @@ def update_workflow_meta(description, state, host, insecure, project, domain, na client.update_named_entity( _core_identifier.ResourceType.WORKFLOW, _named_entity.NamedEntityIdentifier(project, domain, name), - _named_entity.NamedEntityMetadata(description, state)) + 
_named_entity.NamedEntityMetadata(description, state), + ) _click.echo("Successfully updated workflow") -@_flyte_cli.command('update-task-meta', cls=_FlyteSubCommand) +@_flyte_cli.command("update-task-meta", cls=_FlyteSubCommand) @_named_entity_description_option @_host_option @_insecure_option @@ -1730,11 +1610,12 @@ def update_task_meta(description, host, insecure, project, domain, name): client.update_named_entity( _core_identifier.ResourceType.TASK, _named_entity.NamedEntityIdentifier(project, domain, name), - _named_entity.NamedEntityMetadata(description, _named_entity.NamedEntityState.ACTIVE)) + _named_entity.NamedEntityMetadata(description, _named_entity.NamedEntityState.ACTIVE), + ) _click.echo("Successfully updated task") -@_flyte_cli.command('update-launch-plan-meta', cls=_FlyteSubCommand) +@_flyte_cli.command("update-launch-plan-meta", cls=_FlyteSubCommand) @_named_entity_description_option @_host_option @_insecure_option @@ -1750,17 +1631,18 @@ def update_launch_plan_meta(description, host, insecure, project, domain, name): client.update_named_entity( _core_identifier.ResourceType.LAUNCH_PLAN, _named_entity.NamedEntityIdentifier(project, domain, name), - _named_entity.NamedEntityMetadata(description, _named_entity.NamedEntityState.ACTIVE)) + _named_entity.NamedEntityMetadata(description, _named_entity.NamedEntityState.ACTIVE), + ) _click.echo("Successfully updated launch plan") -@_flyte_cli.command('update-cluster-resource-attributes', cls=_FlyteSubCommand) +@_flyte_cli.command("update-cluster-resource-attributes", cls=_FlyteSubCommand) @_host_option @_insecure_option @_project_option @_domain_option @_optional_name_option -@_click.option('--attributes', type=(str, str), multiple=True) +@_click.option("--attributes", type=(str, str), multiple=True) def update_cluster_resource_attributes(host, insecure, project, domain, name, attributes): """ Sets matchable cluster resource attributes for a project, domain and optionally, workflow name. @@ -1775,20 +1657,20 @@ def update_cluster_resource_attributes(host, insecure, project, domain, name, at matching_attributes = _MatchingAttributes(cluster_resource_attributes=cluster_resource_attributes) if name is not None: - client.update_workflow_attributes( - project, domain, name, matching_attributes + client.update_workflow_attributes(project, domain, name, matching_attributes) + _click.echo( + "Successfully updated cluster resource attributes for project: {}, domain: {}, and workflow: {}".format( + project, domain, name + ) ) - _click.echo("Successfully updated cluster resource attributes for project: {}, domain: {}, and workflow: {}". - format(project, domain, name)) else: - client.update_project_domain_attributes( - project, domain, matching_attributes + client.update_project_domain_attributes(project, domain, matching_attributes) + _click.echo( + "Successfully updated cluster resource attributes for project: {} and domain: {}".format(project, domain) ) - _click.echo("Successfully updated cluster resource attributes for project: {} and domain: {}". 
- format(project, domain)) -@_flyte_cli.command('update-execution-queue-attributes', cls=_FlyteSubCommand) +@_flyte_cli.command("update-execution-queue-attributes", cls=_FlyteSubCommand) @_host_option @_insecure_option @_project_option @@ -1809,26 +1691,26 @@ def update_execution_queue_attributes(host, insecure, project, domain, name, tag matching_attributes = _MatchingAttributes(execution_queue_attributes=execution_queue_attributes) if name is not None: - client.update_workflow_attributes( - project, domain, name, matching_attributes + client.update_workflow_attributes(project, domain, name, matching_attributes) + _click.echo( + "Successfully updated execution queue attributes for project: {}, domain: {}, and workflow: {}".format( + project, domain, name + ) ) - _click.echo("Successfully updated execution queue attributes for project: {}, domain: {}, and workflow: {}". - format(project, domain, name)) else: - client.update_project_domain_attributes( - project, domain, matching_attributes + client.update_project_domain_attributes(project, domain, matching_attributes) + _click.echo( + "Successfully updated execution queue attributes for project: {} and domain: {}".format(project, domain) ) - _click.echo("Successfully updated execution queue attributes for project: {} and domain: {}". - format(project, domain)) -@_flyte_cli.command('update-execution-cluster-label', cls=_FlyteSubCommand) +@_flyte_cli.command("update-execution-cluster-label", cls=_FlyteSubCommand) @_host_option @_insecure_option @_project_option @_domain_option @_optional_name_option -@_click.option("--value", help="Cluster label for which to schedule matching executions") +@_click.option("--value", help="Cluster label for which to schedule matching executions") def update_execution_cluster_label(host, insecure, project, domain, name, value): """ Label value to determine where an execution's task will be run for tasks belonging to a project, domain and @@ -1843,20 +1725,20 @@ def update_execution_cluster_label(host, insecure, project, domain, name, value) matching_attributes = _MatchingAttributes(execution_cluster_label=execution_cluster_label) if name is not None: - client.update_workflow_attributes( - project, domain, name, matching_attributes + client.update_workflow_attributes(project, domain, name, matching_attributes) + _click.echo( + "Successfully updated execution cluster label for project: {}, domain: {}, and workflow: {}".format( + project, domain, name + ) ) - _click.echo("Successfully updated execution cluster label for project: {}, domain: {}, and workflow: {}". - format(project, domain, name)) else: - client.update_project_domain_attributes( - project, domain, matching_attributes + client.update_project_domain_attributes(project, domain, matching_attributes) + _click.echo( + "Successfully updated execution cluster label for project: {} and domain: {}".format(project, domain) ) - _click.echo("Successfully updated execution cluster label for project: {} and domain: {}". 
- format(project, domain)) -@_flyte_cli.command('setup-config', cls=_click.Command) +@_flyte_cli.command("setup-config", cls=_click.Command) @_host_option @_insecure_option def setup_config(host, insecure): @@ -1867,13 +1749,15 @@ def setup_config(host, insecure): _welcome_message() config_file = _get_config_file_path() if _get_user_filepath_home() and _os.path.exists(config_file): - _click.secho("Config file already exists at {}".format(_tt(config_file)), fg='blue') + _click.secho("Config file already exists at {}".format(_tt(config_file)), fg="blue") return # Before creating check that the directory exists and create if not config_dir = _os.path.join(_get_user_filepath_home(), _default_config_file_dir) if not _os.path.isdir(config_dir): - _click.secho("Creating default Flyte configuration directory at {}".format(_tt(config_dir)), fg='blue') + _click.secho( + "Creating default Flyte configuration directory at {}".format(_tt(config_dir)), fg="blue", + ) _os.mkdir(config_dir) full_host = "http://{}".format(host) if insecure else "https://{}".format(host) @@ -1899,7 +1783,7 @@ def setup_config(host, insecure): f.write("auth_mode=standard") f.write("\n") set_flyte_config_file(config_file_path=config_file) - _click.secho("Wrote default config file to {}".format(_tt(config_file)), fg='blue') + _click.secho("Wrote default config file to {}".format(_tt(config_file)), fg="blue") if __name__ == "__main__": diff --git a/flytekit/clis/helpers.py b/flytekit/clis/helpers.py index 8743bd2fd2..741cf5dbf2 100644 --- a/flytekit/clis/helpers.py +++ b/flytekit/clis/helpers.py @@ -41,8 +41,7 @@ def parse_args_into_dict(input_arguments): :rtype: dict[Text, Text] """ - return {split_arg[0]: split_arg[1] for split_arg in - [input_arg.split('=', 1) for input_arg in input_arguments]} + return {split_arg[0]: split_arg[1] for split_arg in [input_arg.split("=", 1) for input_arg in input_arguments]} def construct_literal_map_from_parameter_map(parameter_map, text_args): @@ -65,7 +64,7 @@ def construct_literal_map_from_parameter_map(parameter_map, text_args): if var_name in text_args and text_args[var_name] is not None: inputs[var_name] = sdk_type.from_string(text_args[var_name]) else: - raise Exception('Missing required parameter {}'.format(var_name)) + raise Exception("Missing required parameter {}".format(var_name)) else: if var_name in text_args and text_args[var_name] is not None: inputs[var_name] = sdk_type.from_string(text_args[var_name]) @@ -81,4 +80,4 @@ def str2bool(str): :param Text str: :rtype: bool """ - return not str.lower() in ['false', '0', 'off', 'no'] + return str.lower() not in ["false", "0", "off", "no"] diff --git a/flytekit/clis/sdk_in_container/basic_auth.py b/flytekit/clis/sdk_in_container/basic_auth.py index 05671752e0..e74d0bd08f 100644 --- a/flytekit/clis/sdk_in_container/basic_auth.py +++ b/flytekit/clis/sdk_in_container/basic_auth.py @@ -6,11 +6,9 @@ import requests as _requests from flytekit.common.exceptions.user import FlyteAuthenticationException as _FlyteAuthenticationException -from flytekit.configuration.creds import ( - CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET, -) +from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET -_utf_8 = 'utf-8' +_utf_8 = "utf-8" def get_secret(): @@ -22,7 +20,7 @@ def get_secret(): secret = _CREDENTIALS_SECRET.get() if secret: return secret - raise _FlyteAuthenticationException('No secret could be found') + raise _FlyteAuthenticationException("No secret could be found") def
get_basic_authorization_header(client_id, client_secret): @@ -46,20 +44,20 @@ def get_token(token_endpoint, authorization_header, scope): in seconds """ headers = { - 'Authorization': authorization_header, - 'Cache-Control': 'no-cache', - 'Accept': 'application/json', - 'Content-Type': 'application/x-www-form-urlencoded' + "Authorization": authorization_header, + "Cache-Control": "no-cache", + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded", } body = { - 'grant_type': 'client_credentials', + "grant_type": "client_credentials", } if scope is not None: - body['scope'] = scope + body["scope"] = scope response = _requests.post(token_endpoint, data=body, headers=headers) if response.status_code != 200: _logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text)) - raise _FlyteAuthenticationException('Non-200 received from IDP') + raise _FlyteAuthenticationException("Non-200 received from IDP") response = response.json() - return response['access_token'], response['expires_in'] + return response["access_token"], response["expires_in"] diff --git a/flytekit/clis/sdk_in_container/constants.py b/flytekit/clis/sdk_in_container/constants.py index 44313cff7f..34dc4460bd 100644 --- a/flytekit/clis/sdk_in_container/constants.py +++ b/flytekit/clis/sdk_in_container/constants.py @@ -1,6 +1,6 @@ -CTX_PROJECT = 'project' -CTX_DOMAIN = 'domain' -CTX_VERSION = 'version' -CTX_TEST = 'test' -CTX_PACKAGES = 'pkgs' -CTX_NOTIFICATIONS = 'notifications' +CTX_PROJECT = "project" +CTX_DOMAIN = "domain" +CTX_VERSION = "version" +CTX_TEST = "test" +CTX_PACKAGES = "pkgs" +CTX_NOTIFICATIONS = "notifications" diff --git a/flytekit/clis/sdk_in_container/launch_plan.py b/flytekit/clis/sdk_in_container/launch_plan.py index de5bb7d497..bb8e5e0239 100644 --- a/flytekit/clis/sdk_in_container/launch_plan.py +++ b/flytekit/clis/sdk_in_container/launch_plan.py @@ -1,15 +1,16 @@ from __future__ import absolute_import +import logging as _logging + import click import six as _six -import logging as _logging from flytekit.clis.helpers import construct_literal_map_from_parameter_map as _construct_literal_map_from_parameter_map from flytekit.clis.sdk_in_container import constants as _constants from flytekit.common import utils as _utils from flytekit.common.launch_plan import SdkLaunchPlan as _SdkLaunchPlan -from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag, \ - IMAGE as _IMAGE +from flytekit.configuration.internal import IMAGE as _IMAGE +from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag from flytekit.models import launch_plan as _launch_plan_model from flytekit.models.core import identifier as _identifier from flytekit.tools.module_loader import iterate_registerable_entities_in_order @@ -30,13 +31,13 @@ def list_commands(self, ctx): pkgs = ctx.obj[_constants.CTX_PACKAGES] # Discover all launch plans by loading the modules for m, k, lp in iterate_registerable_entities_in_order( - pkgs, include_entities={_SdkLaunchPlan}, - detect_unreferenced_entities=False): + pkgs, include_entities={_SdkLaunchPlan}, detect_unreferenced_entities=False + ): safe_name = _utils.fqdn(m.__name__, k, entity_type=lp.resource_type) commands.append(safe_name) lps[safe_name] = lp - ctx.obj['lps'] = lps + ctx.obj["lps"] = lps commands.sort() return commands @@ -48,24 +49,25 @@ def get_command(self, ctx, lp_argument): launch_plan = None pkgs = ctx.obj[_constants.CTX_PACKAGES] 
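# list_commands (above) caches each discovered launch plan in ctx.obj["lps"];
# the branch below reuses that cache when it exists and only re-walks the
# registerable entities in pkgs when get_command runs without a prior listing.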
- if 'lps' in ctx.obj: - launch_plan = ctx.obj['lps'][lp_argument] + if "lps" in ctx.obj: + launch_plan = ctx.obj["lps"][lp_argument] else: for m, k, lp in iterate_registerable_entities_in_order( - pkgs, include_entities={_SdkLaunchPlan}, detect_unreferenced_entities=False): + pkgs, include_entities={_SdkLaunchPlan}, detect_unreferenced_entities=False, + ): safe_name = _utils.fqdn(m.__name__, k, entity_type=lp.resource_type) if lp_argument == safe_name: launch_plan = lp if launch_plan is None: - raise Exception('Could not load launch plan {}'.format(lp_argument)) + raise Exception("Could not load launch plan {}".format(lp_argument)) launch_plan._id = _identifier.Identifier( _identifier.ResourceType.LAUNCH_PLAN, ctx.obj[_constants.CTX_PROJECT], ctx.obj[_constants.CTX_DOMAIN], lp_argument, - ctx.obj[_constants.CTX_VERSION] + ctx.obj[_constants.CTX_VERSION], ) return self._get_command(ctx, launch_plan, lp_argument) @@ -79,7 +81,6 @@ def _get_command(self, ctx, lp, cmd_name): class LaunchPlanExecuteGroup(LaunchPlanAbstractGroup): - def _get_command(self, ctx, lp, cmd_name): """ This function returns the function that click will actually use to execute a specific launch plan. It also @@ -100,10 +101,11 @@ def _execute_lp(**kwargs): ctx.obj[_constants.CTX_PROJECT], ctx.obj[_constants.CTX_DOMAIN], literal_inputs=inputs, - notification_overrides=ctx.obj.get(_constants.CTX_NOTIFICATIONS, None) + notification_overrides=ctx.obj.get(_constants.CTX_NOTIFICATIONS, None), + ) + click.echo( + click.style("Workflow scheduled, execution_id={}".format(_six.text_type(execution.id)), fg="blue",) ) - click.echo(click.style("Workflow scheduled, execution_id={}".format( - _six.text_type(execution.id)), fg='blue')) command = click.Command(name=cmd_name, callback=_execute_lp) @@ -112,15 +114,13 @@ def _execute_lp(**kwargs): param = lp.default_inputs.parameters[var_name] # TODO: Figure out how to better handle the fact that we want strings to parse, # but we probably shouldn't have click say that that's the type on the CLI. - help_msg = '{} Type: {}'.format( - _six.text_type(param.var.description), - _six.text_type(param.var.type) + help_msg = "{} Type: {}".format( + _six.text_type(param.var.description), _six.text_type(param.var.type) ).strip() if param.required: # If it's a required input, add the required flag - wrapper = click.option('--{}'.format(var_name), required=True, type=_six.text_type, - help=help_msg) + wrapper = click.option("--{}".format(var_name), required=True, type=_six.text_type, help=help_msg,) else: # If it's not a required input, it should have a default # Use to_python_std so that the text of the default ends up being parseable, if not, the click @@ -128,16 +128,19 @@ def _execute_lp(**kwargs): # we'd get '11' and then we'd need annoying logic to differentiate between the default text # and user text. default = param.default.to_python_std() - wrapper = click.option('--{}'.format(var_name), default='{}'.format(_six.text_type(default)), - type=_six.text_type, - help='{}. Default: {}'.format(help_msg, _six.text_type(default))) + wrapper = click.option( + "--{}".format(var_name), + default="{}".format(_six.text_type(default)), + type=_six.text_type, + help="{}. 
Default: {}".format(help_msg, _six.text_type(default)), + ) command = wrapper(command) return command -@click.group('lp') +@click.group("lp") @click.pass_context def launch_plans(ctx): """ @@ -146,7 +149,7 @@ def launch_plans(ctx): pass -@click.group('execute', cls=LaunchPlanExecuteGroup) +@click.group("execute", cls=LaunchPlanExecuteGroup) @click.pass_context def execute_launch_plan(ctx): """ @@ -169,16 +172,20 @@ def activate_all_impl(project, domain, version, pkgs, ignore_schedules=False): project, domain, _utils.fqdn(m.__name__, k, entity_type=lp.resource_type), - version + version, ) if not (lp.is_scheduled and ignore_schedules): _logging.info(f"Setting active {_utils.fqdn(m.__name__, k, entity_type=lp.resource_type)}") lp.update(_launch_plan_model.LaunchPlanState.ACTIVE) -@click.command('activate-all-schedules') -@click.option('-v', '--version', type=str, help='Version to register tasks with. This is normally parsed from the' - 'image, but you can override here.') +@click.command("activate-all-schedules") +@click.option( + "-v", + "--version", + type=str, + help="Version to register tasks with. This is normally parsed from the" "image, but you can override here.", +) @click.pass_context def activate_all_schedules(ctx, version=None): """ @@ -186,7 +193,9 @@ def activate_all_schedules(ctx, version=None): The behavior of this command is identical to activate-all. """ - click.secho("activate-all-schedules is deprecated, please use activate-all instead.", color="yellow") + click.secho( + "activate-all-schedules is deprecated, please use activate-all instead.", color="yellow", + ) project = ctx.obj[_constants.CTX_PROJECT] domain = ctx.obj[_constants.CTX_DOMAIN] pkgs = ctx.obj[_constants.CTX_PACKAGES] @@ -194,10 +203,16 @@ def activate_all_schedules(ctx, version=None): activate_all_impl(project, domain, version, pkgs) -@click.command('activate-all') -@click.option('-v', '--version', type=str, help='Version to register tasks with. This is normally parsed from the' - 'image, but you can override here.') -@click.option("--ignore-schedules", is_flag=True, help='Activate all except for launch plans with schedules.') +@click.command("activate-all") +@click.option( + "-v", + "--version", + type=str, + help="Version to register tasks with. 
This is normally parsed from the " "image, but you can override here.",
+)
+@click.option(
+    "--ignore-schedules", is_flag=True, help="Activate all except for launch plans with schedules.",
+)
 @click.pass_context
 def activate_all(ctx, version=None, ignore_schedules=False):
     """
diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py
index a6963496a5..1be298b9c2 100644
--- a/flytekit/clis/sdk_in_container/pyflyte.py
+++ b/flytekit/clis/sdk_in_container/pyflyte.py
@@ -1,8 +1,8 @@
-from __future__ import absolute_import
-from __future__ import print_function
+from __future__ import absolute_import, print_function

-import os as _os
 import logging as _logging
+import os as _os
+
 import click

 try:
@@ -10,32 +10,49 @@
 except ImportError:
     from pathlib2 import Path  # python 2 backport

-from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES
+from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PACKAGES, CTX_PROJECT, CTX_VERSION
+from flytekit.clis.sdk_in_container.launch_plan import launch_plans
 from flytekit.clis.sdk_in_container.register import register
 from flytekit.clis.sdk_in_container.serialize import serialize
-from flytekit.clis.sdk_in_container.constants import CTX_PROJECT, CTX_DOMAIN, CTX_VERSION
-from flytekit.clis.sdk_in_container.launch_plan import launch_plans
-from flytekit.configuration import internal as _internal_config, platform as _platform_config, sdk as _sdk_config
-
+from flytekit.configuration import internal as _internal_config
+from flytekit.configuration import platform as _platform_config
+from flytekit.configuration import sdk as _sdk_config
+from flytekit.configuration import set_flyte_config_file
 from flytekit.configuration.internal import CONFIGURATION_PATH
+from flytekit.configuration.internal import IMAGE as _IMAGE
+from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag
 from flytekit.configuration.platform import URL as _URL
-from flytekit.configuration import set_flyte_config_file
-from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag, \
-    IMAGE as _IMAGE
 from flytekit.configuration.sdk import WORKFLOW_PACKAGES as _WORKFLOW_PACKAGES


-@click.group('pyflyte', invoke_without_command=True)
-@click.option('-p', '--project', required=True, type=str,
-              help='Flyte project to use. You can have more than one project per repo')
-@click.option('-d', '--domain', required=True, type=str, help='This is usually development, staging, or production')
-@click.option('-c', '--config', required=False, type=str, help='Path to config file for use within container')
-@click.option('-k', '--pkgs', required=False, multiple=True,
-              help='Dot separated python packages to operate on. Multiple may be specified Please note that this '
-                   'option will override the option specified in the configuration file, or environment variable')
-@click.option('-v', '--version', required=False, type=str, help='This is the version to apply globally for this '
-                                                                'context')
-@click.option('-i', '--insecure', required=False, type=bool, help='Do not use SSL to connect to Admin')
+@click.group("pyflyte", invoke_without_command=True)
+@click.option(
+    "-p",
+    "--project",
+    required=True,
+    type=str,
+    help="Flyte project to use. 
You can have more than one project per repo",
+)
+@click.option(
+    "-d", "--domain", required=True, type=str, help="This is usually development, staging, or production",
+)
+@click.option(
+    "-c", "--config", required=False, type=str, help="Path to config file for use within container",
+)
+@click.option(
+    "-k",
+    "--pkgs",
+    required=False,
+    multiple=True,
+    help="Dot separated python packages to operate on. Multiple may be specified. Please note that this "
+    "option will override the option specified in the configuration file, or environment variable",
+)
+@click.option(
+    "-v", "--version", required=False, type=str, help="This is the version to apply globally for this " "context",
+)
+@click.option(
+    "-i", "--insecure", required=False, type=bool, help="Do not use SSL to connect to Admin",
+)
 @click.pass_context
 def main(ctx, project, domain, config=None, pkgs=None, version=None, insecure=None):
     """
@@ -59,7 +76,7 @@ def main(ctx, project, domain, config=None, pkgs=None, version=None, insecure=No
     # says no, let's override the config object by overriding the environment variable.
     if insecure and not _platform_config.INSECURE.get():
         _platform_config.INSECURE.get()
-        _os.environ[_platform_config.INSECURE.env_var] = 'True'
+        _os.environ[_platform_config.INSECURE.env_var] = "True"

     # Handle package management - get from config if not specified on the command line
     pkgs = pkgs or []
@@ -81,14 +98,19 @@ def update_configuration_file(config_file_path):
     """
     configuration_file = Path(config_file_path or CONFIGURATION_PATH.get())
     if configuration_file.is_file():
-        click.secho('Using configuration file at {}'.format(configuration_file.absolute().as_posix()),
-                    fg='green')
+        click.secho(
+            "Using configuration file at {}".format(configuration_file.absolute().as_posix()), fg="green",
+        )
         set_flyte_config_file(configuration_file.as_posix())
     else:
-        click.secho("Configuration file '{}' could not be loaded. Using values from environment.".format(CONFIGURATION_PATH.get()),
-                    color='yellow')
+        click.secho(
+            "Configuration file '{}' could not be loaded. 
Using values from environment.".format(
+                CONFIGURATION_PATH.get()
+            ),
+            fg="yellow",
+        )
         set_flyte_config_file(None)

-    click.secho('Flyte Admin URL {}'.format(_URL.get()), fg='green')
+    click.secho("Flyte Admin URL {}".format(_URL.get()), fg="green")


 main.add_command(register)
diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py
index f90b625298..249a45ab95 100644
--- a/flytekit/clis/sdk_in_container/register.py
+++ b/flytekit/clis/sdk_in_container/register.py
@@ -4,20 +4,23 @@

 import click

-from flytekit.clis.sdk_in_container.constants import CTX_PROJECT, CTX_DOMAIN, CTX_TEST, CTX_PACKAGES, CTX_VERSION
+from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PACKAGES, CTX_PROJECT, CTX_TEST, CTX_VERSION
 from flytekit.common import utils as _utils
 from flytekit.common.core import identifier as _identifier
 from flytekit.common.tasks import task as _task
-from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag, \
-    IMAGE as _IMAGE
+from flytekit.configuration.internal import IMAGE as _IMAGE
+from flytekit.configuration.internal import look_up_version_from_image_tag as _look_up_version_from_image_tag
 from flytekit.tools.module_loader import iterate_registerable_entities_in_order


 def register_all(project, domain, pkgs, test, version):
     if test:
-        click.echo('Test switch enabled, not doing anything...')
-    click.echo('Running task, workflow, and launch plan registration for {}, {}, {} with version {}'.format(
-        project, domain, pkgs, version))
+        click.echo("Test switch enabled, not doing anything...")
+    click.echo(
+        "Running task, workflow, and launch plan registration for {}, {}, {} with version {}".format(
+            project, domain, pkgs, version
+        )
+    )

     # m = module (i.e. python file)
     # k = value of dir(m), type str
@@ -26,13 +29,7 @@ def register_all(project, domain, pkgs, test, version):
     for m, k, o in iterate_registerable_entities_in_order(pkgs):
         name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
         _logging.debug("Found module {}\n   K: {} Instantiated in {}".format(m, k, o._instantiated_in))
-        o._id = _identifier.Identifier(
-            o.resource_type,
-            project,
-            domain,
-            name,
-            version
-        )
+        o._id = _identifier.Identifier(o.resource_type, project, domain, name, version)
         loaded_entities.append(o)

     for o in loaded_entities:
@@ -45,10 +42,9 @@ def register_all(project, domain, pkgs, test, version):

 def register_tasks_only(project, domain, pkgs, test, version):
     if test:
-        click.echo('Test switch enabled, not doing anything...')
+        click.echo("Test switch enabled, not doing anything...")

-    click.echo('Running task only registration for {}, {}, {} with version {}'.format(
-        project, domain, pkgs, version))
+    click.echo("Running task only registration for {}, {}, {} with version {}".format(project, domain, pkgs, version))

     # Discover all tasks by loading the module
     for m, k, t in iterate_registerable_entities_in_order(pkgs, include_entities={_task.SdkTask}):
@@ -61,10 +57,12 @@ def register_tasks_only(project, domain, pkgs, test, version):
         t.register(project, domain, name, version)


-@click.group('register')
+@click.group("register")
 # --pkgs on the register group is DEPRECATED, use same arg on pyflyte.main instead
-@click.option('--pkgs', multiple=True, help="DEPRECATED. This arg can only be used before the 'register' keyword")
-@click.option('--test', is_flag=True, help='Dry run, do not actually register with Admin')
+@click.option(
+    "--pkgs", multiple=True, help="DEPRECATED. 
This arg can only be used before the 'register' keyword",
+)
+@click.option("--test", is_flag=True, help="Dry run, do not actually register with Admin")
 @click.pass_context
 def register(ctx, pkgs=None, test=None):
     """
@@ -79,9 +77,13 @@ def register(ctx, pkgs=None, test=None):
     ctx.obj[CTX_TEST] = test


-@click.command('tasks')
-@click.option('-v', '--version', type=str, help='Version to register tasks with. This is normally parsed from the'
-                                                'image, but you can override here.')
+@click.command("tasks")
+@click.option(
+    "-v",
+    "--version",
+    type=str,
+    help="Version to register tasks with. This is normally parsed from the " "image, but you can override here.",
+)
 @click.pass_context
 def tasks(ctx, version=None):
     """
@@ -96,9 +98,13 @@ def tasks(ctx, version=None):
     register_tasks_only(project, domain, pkgs, test, version)


-@click.command('workflows')
-@click.option('-v', '--version', type=str, help='Version to register tasks with. This is normally parsed from the'
-                                                'image, but you can override here.')
+@click.command("workflows")
+@click.option(
+    "-v",
+    "--version",
+    type=str,
+    help="Version to register tasks with. This is normally parsed from the " "image, but you can override here.",
+)
 @click.pass_context
 def workflows(ctx, version=None):
     """
diff --git a/flytekit/clis/sdk_in_container/serialize.py b/flytekit/clis/sdk_in_container/serialize.py
index 5730bd07c4..a977c42b03 100644
--- a/flytekit/clis/sdk_in_container/serialize.py
+++ b/flytekit/clis/sdk_in_container/serialize.py
@@ -1,5 +1,4 @@
-from __future__ import absolute_import
-from __future__ import print_function
+from __future__ import absolute_import, print_function

 import logging as _logging
 import math as _math
@@ -7,7 +6,7 @@

 import click

-from flytekit.clis.sdk_in_container.constants import CTX_PACKAGES, CTX_PROJECT, CTX_DOMAIN, CTX_VERSION
+from flytekit.clis.sdk_in_container.constants import CTX_DOMAIN, CTX_PACKAGES, CTX_PROJECT, CTX_VERSION
 from flytekit.common import utils as _utils
 from flytekit.common.core import identifier as _identifier
 from flytekit.common.exceptions.scopes import system_entry_point
@@ -36,26 +35,20 @@ def serialize_tasks_only(project, domain, pkgs, version, folder=None):
     for m, k, o in iterate_registerable_entities_in_order(pkgs, include_entities={_sdk_task.SdkTask}):
         name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
         _logging.debug("Found module {}\n   K: {} Instantiated in {}".format(m, k, o._instantiated_in))
-        o._id = _identifier.Identifier(
-            o.resource_type,
-            project,
-            domain,
-            name,
-            version
-        )
+        o._id = _identifier.Identifier(o.resource_type, project, domain, name, version)
         loaded_entities.append(o)

     zero_padded_length = _determine_text_chars(len(loaded_entities))
     for i, entity in enumerate(loaded_entities):
         serialized = entity.serialize()
         fname_index = str(i).zfill(zero_padded_length)
-        fname = '{}_{}.pb'.format(fname_index, entity._id.name)
-        click.echo('  Writing {} to\n    {}'.format(entity._id, fname))
+        fname = "{}_{}.pb".format(fname_index, entity._id.name)
+        click.echo("  Writing {} to\n    {}".format(entity._id, fname))
         if folder:
             fname = _os.path.join(folder, fname)
         _write_proto_to_file(serialized, fname)

-        identifier_fname = '{}_{}.identifier.pb'.format(fname_index, entity._id.name)
+        identifier_fname = "{}_{}.identifier.pb".format(fname_index, entity._id.name)
         if folder:
             identifier_fname = _os.path.join(folder, identifier_fname)
         _write_proto_to_file(entity._id.to_flyte_idl(), identifier_fname)
@@ -92,21 +85,15 @@ def serialize_all(project, domain, pkgs, version, 
folder=None):
     for m, k, o in iterate_registerable_entities_in_order(pkgs):
         name = _utils.fqdn(m.__name__, k, entity_type=o.resource_type)
         _logging.debug("Found module {}\n   K: {} Instantiated in {}".format(m, k, o._instantiated_in))
-        o._id = _identifier.Identifier(
-            o.resource_type,
-            project,
-            domain,
-            name,
-            version
-        )
+        o._id = _identifier.Identifier(o.resource_type, project, domain, name, version)
         loaded_entities.append(o)

     zero_padded_length = _determine_text_chars(len(loaded_entities))
     for i, entity in enumerate(loaded_entities):
         serialized = entity.serialize()
         fname_index = str(i).zfill(zero_padded_length)
-        fname = '{}_{}.pb'.format(fname_index, entity._id.name)
-        click.echo('  Writing {} to\n    {}'.format(entity._id, fname))
+        fname = "{}_{}.pb".format(fname_index, entity._id.name)
+        click.echo("  Writing {} to\n    {}".format(entity._id, fname))
         if folder:
             fname = _os.path.join(folder, fname)
         _write_proto_to_file(serialized, fname)
@@ -116,7 +103,7 @@ def serialize_all(project, domain, pkgs, version, folder=None):
     # project/domain, etc.) made for this serialize call. We should not allow users to specify a different project
     # for instance come registration time, to avoid mismatches between potential internal ids like the TaskTemplate
     # and the registered entity.
-    identifier_fname = '{}_{}.identifier.pb'.format(fname_index, entity._id.name)
+    identifier_fname = "{}_{}.identifier.pb".format(fname_index, entity._id.name)
     if folder:
         identifier_fname = _os.path.join(folder, identifier_fname)
     _write_proto_to_file(entity._id.to_flyte_idl(), identifier_fname)
@@ -133,7 +120,7 @@ def _determine_text_chars(length):
     return _math.ceil(_math.log(length, 10))


-@click.group('serialize')
+@click.group("serialize")
 @click.pass_context
 def serialize(ctx):
     """
@@ -143,13 +130,17 @@ def serialize(ctx):
     object contains the WorkflowTemplate, along with the relevant tasks for that workflow.  In lieu of Admin, this
     serialization step will set the URN of the tasks to the fully qualified name of the task function.
     """
-    click.echo('Serializing Flyte elements with image {}'.format(_internal_configuration.IMAGE.get()))
+    click.echo("Serializing Flyte elements with image {}".format(_internal_configuration.IMAGE.get()))


-@click.command('tasks')
-@click.option('-v', '--version', type=str, help='Version to serialize tasks with. This is normally parsed from the'
-                                                'image, but you can override here.')
-@click.option('-f', '--folder', type=click.Path(exists=True))
+@click.command("tasks")
+@click.option(
+    "-v",
+    "--version",
+    type=str,
+    help="Version to serialize tasks with. This is normally parsed from the " "image, but you can override here.",
+)
+@click.option("-f", "--folder", type=click.Path(exists=True))
 @click.pass_context
 def tasks(ctx, version=None, folder=None):
     project = ctx.obj[CTX_PROJECT]
@@ -159,32 +150,40 @@ def tasks(ctx, version=None, folder=None):
     if folder:
         click.echo(f"Writing output to {folder}")

-    version = version or ctx.obj[CTX_VERSION] or _internal_configuration.look_up_version_from_image_tag(
-        _internal_configuration.IMAGE.get())
+    version = (
+        version
+        or ctx.obj[CTX_VERSION]
+        or _internal_configuration.look_up_version_from_image_tag(_internal_configuration.IMAGE.get())
+    )
     internal_settings = {
-        'project': project,
-        'domain': domain,
-        'version': version,
+        "project": project,
+        "domain": domain,
+        "version": version,
     }
     # Populate internal settings for project/domain/version from the environment so that the file names are resolved
     # with the correct strings. 
The file itself doesn't need to change though.
     with TemporaryConfiguration(_internal_configuration.CONFIGURATION_PATH.get(), internal_settings):
-        _logging.debug("Serializing with settings\n"
-                       "\n  Project: {}"
-                       "\n  Domain: {}"
-                       "\n  Version: {}"
-                       "\n\nover the following packages {}".format(project, domain, version, pkgs)
-                       )
+        _logging.debug(
+            "Serializing with settings\n"
+            "\n  Project: {}"
+            "\n  Domain: {}"
+            "\n  Version: {}"
+            "\n\nover the following packages {}".format(project, domain, version, pkgs)
+        )
         serialize_tasks_only(project, domain, pkgs, version, folder)


-@click.command('workflows')
-@click.option('-v', '--version', type=str, help='Version to serialize tasks with. This is normally parsed from the'
-                                                'image, but you can override here.')
+@click.command("workflows")
+@click.option(
+    "-v",
+    "--version",
+    type=str,
+    help="Version to serialize tasks with. This is normally parsed from the " "image, but you can override here.",
+)
 # For now let's just assume that the directory needs to exist. If you're docker run -v'ing, docker will create the
 # directory for you so it shouldn't be a problem.
-@click.option('-f', '--folder', type=click.Path(exists=True))
+@click.option("-f", "--folder", type=click.Path(exists=True))
 @click.pass_context
 def workflows(ctx, version=None, folder=None):
     _logging.getLogger().setLevel(_logging.DEBUG)
@@ -196,23 +195,27 @@ def workflows(ctx, version=None, folder=None):
     domain = ctx.obj[CTX_DOMAIN]
     pkgs = ctx.obj[CTX_PACKAGES]

-    version = version or ctx.obj[CTX_VERSION] or _internal_configuration.look_up_version_from_image_tag(
-        _internal_configuration.IMAGE.get())
+    version = (
+        version
+        or ctx.obj[CTX_VERSION]
+        or _internal_configuration.look_up_version_from_image_tag(_internal_configuration.IMAGE.get())
+    )
     internal_settings = {
-        'project': project,
-        'domain': domain,
-        'version': version,
+        "project": project,
+        "domain": domain,
+        "version": version,
     }
     # Populate internal settings for project/domain/version from the environment so that the file names are resolved
     # with the correct strings. The file itself doesn't need to change though. 
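(Aside: the idiom in this block, push temporary values, run the serialization, then restore, is a small context-manager pattern. A minimal sketch of the idea, with a hypothetical helper and environment key, not flytekit's actual TemporaryConfiguration implementation:)

```python
# Minimal sketch of the temporary-override idiom used by TemporaryConfiguration
# above. Illustrative only; flytekit's real class also layers in a config file.
# The environment variable name below is hypothetical.
import os
from contextlib import contextmanager


@contextmanager
def temporary_env(overrides):
    """Apply environment overrides for the duration of a block, then restore."""
    saved = {key: os.environ.get(key) for key in overrides}
    os.environ.update({key: str(value) for key, value in overrides.items()})
    try:
        yield
    finally:
        for key, value in saved.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value


with temporary_env({"FLYTE_INTERNAL_VERSION": "abc123"}):
    pass  # names would be resolved against the overridden value inside this block
```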
with TemporaryConfiguration(_internal_configuration.CONFIGURATION_PATH.get(), internal_settings): - _logging.debug("Serializing with settings\n" - "\n Project: {}" - "\n Domain: {}" - "\n Version: {}" - "\n\nover the following packages {}".format(project, domain, version, pkgs) - ) + _logging.debug( + "Serializing with settings\n" + "\n Project: {}" + "\n Domain: {}" + "\n Version: {}" + "\n\nover the following packages {}".format(project, domain, version, pkgs) + ) serialize_all(project, domain, pkgs, version, folder) diff --git a/flytekit/common/component_nodes.py b/flytekit/common/component_nodes.py index bd65b23e74..3a8f1d3d65 100644 --- a/flytekit/common/component_nodes.py +++ b/flytekit/common/component_nodes.py @@ -1,15 +1,15 @@ from __future__ import absolute_import -import six as _six import logging as _logging +import six as _six + from flytekit.common import sdk_bases as _sdk_bases from flytekit.common.exceptions import system as _system_exceptions from flytekit.models.core import workflow as _workflow_model class SdkTaskNode(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _workflow_model.TaskNode)): - def __init__(self, sdk_task): """ :param flytekit.common.tasks.task.SdkTask sdk_task: @@ -43,6 +43,7 @@ def promote_from_model(cls, base_model, tasks): :rtype: SdkTaskNode """ from flytekit.common.tasks import task as _task + if base_model.reference_id in tasks: t = tasks[base_model.reference_id] _logging.debug("Found existing task template for {}, will not retrieve from Admin".format(t.id)) @@ -66,9 +67,12 @@ def __init__(self, sdk_workflow=None, sdk_launch_plan=None): :param flytekit.common.launch_plan.SdkLaunchPlan sdk_launch_plan: """ if sdk_workflow and sdk_launch_plan: - raise _system_exceptions.FlyteSystemException("SdkWorkflowNode cannot be called with both a workflow and " - "a launchplan specified, please pick one. WF: {} LP: {}", - sdk_workflow, sdk_launch_plan) + raise _system_exceptions.FlyteSystemException( + "SdkWorkflowNode cannot be called with both a workflow and " + "a launchplan specified, please pick one. WF: {} LP: {}", + sdk_workflow, + sdk_launch_plan, + ) self._sdk_workflow = sdk_workflow self._sdk_launch_plan = sdk_launch_plan @@ -124,7 +128,8 @@ def promote_from_model(cls, base_model, sub_workflows, tasks): :rtype: SdkWorkflowNode """ # put the import statement here to prevent circular dependency error - from flytekit.common import workflow as _workflow, launch_plan as _launch_plan + from flytekit.common import launch_plan as _launch_plan + from flytekit.common import workflow as _workflow project = base_model.reference.project domain = base_model.reference.domain @@ -137,18 +142,19 @@ def promote_from_model(cls, base_model, sub_workflows, tasks): # The workflow templates for sub-workflows should have been included in the original response if base_model.reference in sub_workflows: sw = sub_workflows[base_model.reference] - promoted = _workflow.SdkWorkflow.promote_from_model(sw, sub_workflows=sub_workflows, - tasks=tasks) + promoted = _workflow.SdkWorkflow.promote_from_model(sw, sub_workflows=sub_workflows, tasks=tasks) return cls(sdk_workflow=promoted) # If not found for some reason, fetch it from Admin again. # The reason there is a warning here but not for tasks is because sub-workflows should always be passed # along. Ideally subworkflows are never even registered with Admin, so fetching from Admin ideally doesn't # return anything. 
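(Aside: the promotion logic above follows a look-aside pattern: consult the locally supplied map of sub-workflow templates first, and only warn and fetch remotely when a reference is missing. A compact sketch under those assumptions; `fetch_from_admin` is a stand-in callable, not a real flytekit API:)

```python
# Sketch of the look-aside resolution in SdkWorkflowNode.promote_from_model.
import logging


def resolve_sub_workflow(reference, sub_workflows, fetch_from_admin):
    if reference in sub_workflows:
        return sub_workflows[reference]  # fast path: template was passed along
    logging.warning("Your subworkflow with id %s is not included in the promote call.", reference)
    return fetch_from_admin(reference)  # slow path: round-trip to Admin
```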
- _logging.warning("Your subworkflow with id {} is not included in the promote call.".format( - base_model.reference)) + _logging.warning( + "Your subworkflow with id {} is not included in the promote call.".format(base_model.reference) + ) sdk_workflow = _workflow.SdkWorkflow.fetch(project, domain, name, version) return cls(sdk_workflow=sdk_workflow) else: - raise _system_exceptions.FlyteSystemException("Bad workflow node model, neither subworkflow nor " - "launchplan specified.") + raise _system_exceptions.FlyteSystemException( + "Bad workflow node model, neither subworkflow nor " "launchplan specified." + ) diff --git a/flytekit/common/constants.py b/flytekit/common/constants.py index 8f3af75de3..5fb6fa5dbb 100644 --- a/flytekit/common/constants.py +++ b/flytekit/common/constants.py @@ -1,9 +1,9 @@ from __future__ import absolute_import -INPUT_FILE_NAME = 'inputs.pb' -OUTPUT_FILE_NAME = 'outputs.pb' -FUTURES_FILE_NAME = 'futures.pb' -ERROR_FILE_NAME = 'error.pb' +INPUT_FILE_NAME = "inputs.pb" +OUTPUT_FILE_NAME = "outputs.pb" +FUTURES_FILE_NAME = "futures.pb" +ERROR_FILE_NAME = "error.pb" class SdkTaskType(object): @@ -27,7 +27,8 @@ class SdkTaskType(object): SAGEMAKER_TRAINING_JOB_TASK = "sagemaker_training_job_task" SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK = "sagemaker_hyperparameter_tuning_job_task" -GLOBAL_INPUT_NODE_ID = '' + +GLOBAL_INPUT_NODE_ID = "" START_NODE_ID = "start-node" END_NODE_ID = "end-node" diff --git a/flytekit/common/core/identifier.py b/flytekit/common/core/identifier.py index 0d2133cae1..1ef112215a 100644 --- a/flytekit/common/core/identifier.py +++ b/flytekit/common/core/identifier.py @@ -1,8 +1,10 @@ from __future__ import absolute_import -from flytekit.models.core import identifier as _core_identifier + +import six as _six + from flytekit.common import sdk_bases as _sdk_bases from flytekit.common.exceptions import user as _user_exceptions -import six as _six +from flytekit.models.core import identifier as _core_identifier class Identifier(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _core_identifier.Identifier)): @@ -21,11 +23,7 @@ def promote_from_model(cls, base_model): :rtype: Identifier """ return cls( - base_model.resource_type, - base_model.project, - base_model.domain, - base_model.name, - base_model.version + base_model.resource_type, base_model.project, base_model.domain, base_model.name, base_model.version, ) @classmethod @@ -47,28 +45,19 @@ def from_python_std(cls, string): if resource_type not in cls._STRING_TO_TYPE_MAP: raise _user_exceptions.FlyteValueException( "The provided string could not be parsed. The first element of an identifier must be one of: {}. 
" - "Received: {}".format( - list(cls._STRING_TO_TYPE_MAP.keys()), - resource_type - ) + "Received: {}".format(list(cls._STRING_TO_TYPE_MAP.keys()), resource_type) ) resource_type = cls._STRING_TO_TYPE_MAP[resource_type] - return cls( - resource_type, - project, - domain, - name, - version - ) + return cls(resource_type, project, domain, name, version) def __str__(self): return "{}:{}:{}:{}:{}".format( - type(self)._TYPE_TO_STRING_MAP.get(self.resource_type, ''), + type(self)._TYPE_TO_STRING_MAP.get(self.resource_type, ""), self.project, self.domain, self.name, - self.version + self.version, ) @@ -81,11 +70,7 @@ def promote_from_model(cls, base_model): :param flytekit.models.core.identifier.WorkflowExecutionIdentifier base_model: :rtype: WorkflowExecutionIdentifier """ - return cls( - base_model.project, - base_model.domain, - base_model.name, - ) + return cls(base_model.project, base_model.domain, base_model.name,) @classmethod def from_python_std(cls, string): @@ -99,7 +84,7 @@ def from_python_std(cls, string): raise _user_exceptions.FlyteValueException( string, "The provided string was not in a parseable format. The string for an identifier must be in the format" - " ex:project:domain:name." + " ex:project:domain:name.", ) resource_type, project, domain, name = segments @@ -107,21 +92,13 @@ def from_python_std(cls, string): if resource_type != "ex": raise _user_exceptions.FlyteValueException( resource_type, - "The provided string could not be parsed. The first element of an execution identifier must be 'ex'." + "The provided string could not be parsed. The first element of an execution identifier must be 'ex'.", ) - return cls( - project, - domain, - name, - ) + return cls(project, domain, name,) def __str__(self): - return "ex:{}:{}:{}".format( - self.project, - self.domain, - self.name - ) + return "ex:{}:{}:{}".format(self.project, self.domain, self.name) class TaskExecutionIdentifier( @@ -136,7 +113,7 @@ def promote_from_model(cls, base_model): return cls( task_id=base_model.task_id, node_execution_id=base_model.node_execution_id, - retry_attempt=base_model.retry_attempt + retry_attempt=base_model.retry_attempt, ) @classmethod @@ -151,7 +128,7 @@ def from_python_std(cls, string): raise _user_exceptions.FlyteValueException( string, "The provided string was not in a parseable format. The string for an identifier must be in the format" - " te:exec_project:exec_domain:exec_name:node_id:task_project:task_domain:task_name:task_version:retry." + " te:exec_project:exec_domain:exec_name:node_id:task_project:task_domain:task_name:task_version:retry.", ) resource_type, ep, ed, en, node_id, tp, td, tn, tv, retry = segments @@ -159,14 +136,13 @@ def from_python_std(cls, string): if resource_type != "te": raise _user_exceptions.FlyteValueException( resource_type, - "The provided string could not be parsed. The first element of an execution identifier must be 'ex'." + "The provided string could not be parsed. 
The first element of an execution identifier must be 'te'.",
             )

         return cls(
             task_id=Identifier(_core_identifier.ResourceType.TASK, tp, td, tn, tv),
             node_execution_id=_core_identifier.NodeExecutionIdentifier(
-                node_id=node_id,
-                execution_id=_core_identifier.WorkflowExecutionIdentifier(ep, ed, en),
+                node_id=node_id, execution_id=_core_identifier.WorkflowExecutionIdentifier(ep, ed, en),
             ),
             retry_attempt=int(retry),
         )
diff --git a/flytekit/common/exceptions/base.py b/flytekit/common/exceptions/base.py
index 018902ac1f..a8c1087556 100644
--- a/flytekit/common/exceptions/base.py
+++ b/flytekit/common/exceptions/base.py
@@ -1,4 +1,5 @@
 from __future__ import absolute_import
+
 import six as _six


@@ -9,7 +10,7 @@ def error_code(cls):


 class FlyteException(_six.with_metaclass(_FlyteCodedExceptionMetaclass, Exception)):
-    _ERROR_CODE = 'UnknownFlyteException'
+    _ERROR_CODE = "UnknownFlyteException"


 class FlyteRecoverableException(FlyteException):
diff --git a/flytekit/common/exceptions/scopes.py b/flytekit/common/exceptions/scopes.py
index 456fe1fdc4..671e045db5 100644
--- a/flytekit/common/exceptions/scopes.py
+++ b/flytekit/common/exceptions/scopes.py
@@ -1,15 +1,18 @@
 from __future__ import absolute_import

-from six import reraise as _reraise
-from wrapt import decorator as _decorator
 from sys import exc_info as _exc_info
-from flytekit.common.exceptions import system as _system_exceptions, user as _user_exceptions, base as _base_exceptions
-from flytekit.models.core import errors as _error_model
 from traceback import format_tb as _format_tb

+from six import reraise as _reraise
+from wrapt import decorator as _decorator

-class FlyteScopedException(Exception):
+from flytekit.common.exceptions import base as _base_exceptions
+from flytekit.common.exceptions import system as _system_exceptions
+from flytekit.common.exceptions import user as _user_exceptions
+from flytekit.models.core import errors as _error_model

+
+class FlyteScopedException(Exception):
     def __init__(self, context, exc_type, exc_value, exc_tb, top_trim=0, bottom_trim=0, kind=None):
         self._exc_type = exc_type
         self._exc_value = exc_value
@@ -36,16 +39,10 @@ def verbose_message(self):
         lines = _format_tb(top_tb, limit=limit)
         lines = [line.rstrip() for line in lines]
-        lines = ('\n'.join(lines).split('\n'))
-        traceback_str = '\n    '.join([""] + lines)
-
-        format_str = (
-            "Traceback (most recent call last):\n"
-            "{traceback}\n"
-            "\n"
-            "Message:\n"
-            "\n"
-            "    {message}")
+        lines = "\n".join(lines).split("\n")
+        traceback_str = "\n    ".join([""] + lines)
+
+        format_str = "Traceback (most recent call last):\n" "{traceback}\n" "\n" "Message:\n" "\n" "    {message}"
         return format_str.format(traceback=traceback_str, message=str(self.value))

     def __str__(self):
@@ -101,11 +98,8 @@ def kind(self):


 class FlyteScopedSystemException(FlyteScopedException):
-
     def __init__(self, exc_type, exc_value, exc_tb, **kwargs):
-        super(FlyteScopedSystemException, self).__init__(
-            "SYSTEM", exc_type, exc_value, exc_tb, **kwargs
-        )
+        super(FlyteScopedSystemException, self).__init__("SYSTEM", exc_type, exc_value, exc_tb, **kwargs)

     @property
     def verbose_message(self):
@@ -118,11 +112,8 @@ def verbose_message(self):


 class FlyteScopedUserException(FlyteScopedException):
-
     def __init__(self, exc_type, exc_value, exc_tb, **kwargs):
-        super(FlyteScopedUserException, self).__init__(
-            "USER", exc_type, exc_value, exc_tb, **kwargs
-        )
+        super(FlyteScopedUserException, self).__init__("USER", exc_type, exc_value, exc_tb, **kwargs)

     @property
     def verbose_message(self):
@@ -170,15 
+161,15 @@ def system_entry_point(wrapped, instance, args, kwargs): except _user_exceptions.FlyteUserException: # Re-raise from here. _reraise( - FlyteScopedUserException, - FlyteScopedUserException(*_exc_info()), - _exc_info()[2]) - except: + FlyteScopedUserException, FlyteScopedUserException(*_exc_info()), _exc_info()[2], + ) + except Exception: # System error, raise full stack-trace all the way up the chain. _reraise( FlyteScopedSystemException, FlyteScopedSystemException(*_exc_info(), kind=_error_model.ContainerError.Kind.RECOVERABLE), - _exc_info()[2]) + _exc_info()[2], + ) finally: _CONTEXT_STACK.pop() @@ -209,20 +200,17 @@ def user_entry_point(wrapped, instance, args, kwargs): _reraise(*_exc_info()) except _user_exceptions.FlyteUserException: _reraise( - FlyteScopedUserException, - FlyteScopedUserException(*_exc_info()), - _exc_info()[2]) + FlyteScopedUserException, FlyteScopedUserException(*_exc_info()), _exc_info()[2], + ) except _system_exceptions.FlyteSystemException: _reraise( - FlyteScopedSystemException, - FlyteScopedSystemException(*_exc_info()), - _exc_info()[2]) - except: + FlyteScopedSystemException, FlyteScopedSystemException(*_exc_info()), _exc_info()[2], + ) + except Exception: # Any non-platform raised exception is a user exception. # This will also catch FlyteUserException re-raised by the system_entry_point handler _reraise( - FlyteScopedUserException, - FlyteScopedUserException(*_exc_info()), - _exc_info()[2]) + FlyteScopedUserException, FlyteScopedUserException(*_exc_info()), _exc_info()[2], + ) finally: _CONTEXT_STACK.pop() diff --git a/flytekit/common/exceptions/system.py b/flytekit/common/exceptions/system.py index 549c5a9863..a3787f8234 100644 --- a/flytekit/common/exceptions/system.py +++ b/flytekit/common/exceptions/system.py @@ -1,4 +1,5 @@ from __future__ import absolute_import + from flytekit.common.exceptions import base as _base_exceptions @@ -18,23 +19,21 @@ def _create_verbose_message(cls, task_module, task_name=None, additional_msg=Non if task_name is None: return "Entrypoint is not loadable! Could not load the module: '{task_module}'{additional_msg}".format( task_module=task_module, - additional_msg=" due to error: {}".format(additional_msg) if additional_msg is not None else "." + additional_msg=" due to error: {}".format(additional_msg) if additional_msg is not None else ".", ) else: - return "Entrypoint is not loadable! Could not find the task: '{task_name}' in '{task_module}'" \ - "{additional_msg}".format( - task_module=task_module, - task_name=task_name, - additional_msg="." if additional_msg is None else " due to error: {}".format(additional_msg) - ) + return ( + "Entrypoint is not loadable! Could not find the task: '{task_name}' in '{task_module}'" + "{additional_msg}".format( + task_module=task_module, + task_name=task_name, + additional_msg="." 
if additional_msg is None else " due to error: {}".format(additional_msg), + ) + ) def __init__(self, task_module, task_name=None, additional_msg=None): super(FlyteSystemException, self).__init__( - self._create_verbose_message( - task_module, - task_name=task_name, - additional_msg=additional_msg - ) + self._create_verbose_message(task_module, task_name=task_name, additional_msg=additional_msg) ) diff --git a/flytekit/common/exceptions/user.py b/flytekit/common/exceptions/user.py index 0d7ebad590..102e94413c 100644 --- a/flytekit/common/exceptions/user.py +++ b/flytekit/common/exceptions/user.py @@ -1,5 +1,7 @@ from __future__ import absolute_import -from flytekit.common.exceptions.base import FlyteException as _FlyteException, FlyteRecoverableException as _Recoverable + +from flytekit.common.exceptions.base import FlyteException as _FlyteException +from flytekit.common.exceptions.base import FlyteRecoverableException as _Recoverable class FlyteUserException(_FlyteException): @@ -19,23 +21,22 @@ def _create_verbose_message(cls, received_type, expected_type, received_value=No return "Type error! Received: {} with value: {}, Expected{}: {}. {}".format( received_type, received_value, - ' one of' if FlyteTypeException._is_a_container(expected_type) else '', + " one of" if FlyteTypeException._is_a_container(expected_type) else "", expected_type, - additional_msg or '') + additional_msg or "", + ) else: return "Type error! Received: {}, Expected{}: {}. {}".format( received_type, - ' one of' if FlyteTypeException._is_a_container(expected_type) else '', + " one of" if FlyteTypeException._is_a_container(expected_type) else "", expected_type, - additional_msg or '') + additional_msg or "", + ) def __init__(self, received_type, expected_type, additional_msg=None, received_value=None): super(FlyteTypeException, self).__init__( self._create_verbose_message( - received_type, - expected_type, - received_value=received_value, - additional_msg=additional_msg + received_type, expected_type, received_value=received_value, additional_msg=additional_msg, ) ) @@ -49,9 +50,7 @@ def _create_verbose_message(cls, received_value, error_message): return "Value error! Received: {}. 
{}".format(received_value, error_message) def __init__(self, received_value, error_message): - super(FlyteValueException, self).__init__( - self._create_verbose_message(received_value, error_message) - ) + super(FlyteValueException, self).__init__(self._create_verbose_message(received_value, error_message)) class FlyteAssertion(FlyteUserException, AssertionError): diff --git a/flytekit/common/interface.py b/flytekit/common/interface.py index 5001581ca5..37fe4b64fa 100644 --- a/flytekit/common/interface.py +++ b/flytekit/common/interface.py @@ -2,14 +2,17 @@ import six as _six -from flytekit.common import sdk_bases as _sdk_bases, promise as _promise +from flytekit.common import promise as _promise +from flytekit.common import sdk_bases as _sdk_bases from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import helpers as _type_helpers, containers as _containers, primitives as _primitives -from flytekit.models import interface as _interface_models, literals as _literal_models +from flytekit.common.types import containers as _containers +from flytekit.common.types import helpers as _type_helpers +from flytekit.common.types import primitives as _primitives +from flytekit.models import interface as _interface_models +from flytekit.models import literals as _literal_models class BindingData(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _literal_models.BindingData)): - @staticmethod def _has_sub_bindings(m): """ @@ -29,9 +32,7 @@ def promote_from_model(cls, model): :param flytekit.models.literals.BindingData model: :rtype: BindingData """ - return cls( - scalar=model.scalar, collection=model.collection, promise=model.promise, map=model.map - ) + return cls(scalar=model.scalar, collection=model.collection, promise=model.promise, map=model.map,) @classmethod def from_python_std(cls, literal_type, t_value, upstream_nodes=None): @@ -52,7 +53,7 @@ def from_python_std(cls, literal_type, t_value, upstream_nodes=None): _user_exceptions.FlyteTypeException( t_value.sdk_type, downstream_sdk_type, - additional_msg="When binding workflow input: {}".format(t_value) + additional_msg="When binding workflow input: {}".format(t_value), ) promise = t_value.promise elif isinstance(t_value, _promise.NodeOutput): @@ -60,7 +61,7 @@ def from_python_std(cls, literal_type, t_value, upstream_nodes=None): _user_exceptions.FlyteTypeException( t_value.sdk_type, downstream_sdk_type, - additional_msg="When binding node output: {}".format(t_value) + additional_msg="When binding node output: {}".format(t_value), ) promise = t_value if upstream_nodes is not None: @@ -71,20 +72,19 @@ def from_python_std(cls, literal_type, t_value, upstream_nodes=None): type(t_value), downstream_sdk_type, received_value=t_value, - additional_msg="Cannot bind a list to a non-list type." + additional_msg="Cannot bind a list to a non-list type.", ) collection = _literal_models.BindingDataCollection( [ BindingData.from_python_std( - downstream_sdk_type.sub_type.to_flyte_literal_type(), - v, - upstream_nodes=upstream_nodes + downstream_sdk_type.sub_type.to_flyte_literal_type(), v, upstream_nodes=upstream_nodes, ) for v in t_value ] ) - elif isinstance(t_value, dict) and \ - (not issubclass(downstream_sdk_type, _primitives.Generic) or BindingData._has_sub_bindings(t_value)): + elif isinstance(t_value, dict) and ( + not issubclass(downstream_sdk_type, _primitives.Generic) or BindingData._has_sub_bindings(t_value) + ): # TODO: This behavior should be embedded in the type engine. 
Someone should be able to alter behavior of # TODO: binding logic by injecting their own type engine. The same goes for the list check above. raise NotImplementedError("TODO: Cannot use map bindings at the moment") @@ -97,7 +97,6 @@ def from_python_std(cls, literal_type, t_value, upstream_nodes=None): class TypedInterface(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _interface_models.TypedInterface)): - @classmethod def promote_from_model(cls, model): """ @@ -118,14 +117,10 @@ def create_bindings_for_inputs(self, map_of_bindings): for k in sorted(self.inputs): var = self.inputs[k] if k not in map_of_bindings: - raise _user_exceptions.FlyteAssertion( - "Input was not specified for: {} of type {}".format(k, var.type) - ) + raise _user_exceptions.FlyteAssertion("Input was not specified for: {} of type {}".format(k, var.type)) binding_data[k] = BindingData.from_python_std( - var.type, - map_of_bindings[k], - upstream_nodes=all_upstream_nodes + var.type, map_of_bindings[k], upstream_nodes=all_upstream_nodes ) extra_inputs = set(binding_data.keys()) ^ set(map_of_bindings.keys()) @@ -141,7 +136,10 @@ def create_bindings_for_inputs(self, map_of_bindings): seen_nodes.add(n) min_upstream.append(n) - return [_literal_models.Binding(k, bd) for k, bd in _six.iteritems(binding_data)], min_upstream + return ( + [_literal_models.Binding(k, bd) for k, bd in _six.iteritems(binding_data)], + min_upstream, + ) def __repr__(self): return "({inputs}) -> ({outputs})".format( @@ -156,5 +154,5 @@ def __repr__(self): "{}: {}".format(k, _type_helpers.get_sdk_type_from_literal_type(v.type)) for k, v in _six.iteritems(self.outputs) ] - ) + ), ) diff --git a/flytekit/common/launch_plan.py b/flytekit/common/launch_plan.py index d2b9c4bab4..a6d217e710 100644 --- a/flytekit/common/launch_plan.py +++ b/flytekit/common/launch_plan.py @@ -1,29 +1,39 @@ from __future__ import absolute_import -from flytekit.common import sdk_bases as _sdk_bases, promise as _promises, interface as _interface, nodes as _nodes, \ - workflow_execution as _workflow_execution -from flytekit.common.core import identifier as _identifier -from flytekit.common.exceptions import scopes as _exception_scopes, user as _user_exceptions - -from flytekit.common.mixins import registerable as _registerable, hash as _hash_mixin, launchable as _launchable_mixin -from flytekit.common.types import helpers as _type_helpers -from flytekit.configuration import sdk as _sdk_config, auth as _auth_config -from flytekit.engines import loader as _engine_loader -from flytekit.models import launch_plan as _launch_plan_models, schedule as _schedule_model, interface as \ - _interface_models, literals as _literal_models, common as _common_models -from flytekit.models.core import identifier as _identifier_model, workflow as _workflow_models import datetime as _datetime -from deprecated import deprecated as _deprecated import logging as _logging -import six as _six import uuid as _uuid +import six as _six +from deprecated import deprecated as _deprecated + +from flytekit.common import interface as _interface +from flytekit.common import nodes as _nodes +from flytekit.common import promise as _promises +from flytekit.common import sdk_bases as _sdk_bases +from flytekit.common import workflow_execution as _workflow_execution +from flytekit.common.core import identifier as _identifier +from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.mixins import hash as _hash_mixin +from 
flytekit.common.mixins import launchable as _launchable_mixin +from flytekit.common.mixins import registerable as _registerable +from flytekit.common.types import helpers as _type_helpers +from flytekit.configuration import auth as _auth_config +from flytekit.configuration import sdk as _sdk_config +from flytekit.engines import loader as _engine_loader +from flytekit.models import common as _common_models +from flytekit.models import interface as _interface_models +from flytekit.models import launch_plan as _launch_plan_models +from flytekit.models import literals as _literal_models +from flytekit.models import schedule as _schedule_model +from flytekit.models.core import identifier as _identifier_model +from flytekit.models.core import workflow as _workflow_models + class SdkLaunchPlan( _six.with_metaclass( - _sdk_bases.ExtendedSdkType, - _launch_plan_models.LaunchPlanSpec, - _launchable_mixin.LaunchableEntity, + _sdk_bases.ExtendedSdkType, _launch_plan_models.LaunchPlanSpec, _launchable_mixin.LaunchableEntity, ) ): def __init__(self, *args, **kwargs): @@ -69,6 +79,7 @@ def fetch(cls, project, domain, name, version=None): :rtype: SdkLaunchPlan """ from flytekit.common import workflow as _workflow + launch_plan_id = _identifier.Identifier( _identifier_model.ResourceType.LAUNCH_PLAN, project, domain, name, version ) @@ -107,19 +118,22 @@ def auth_role(self): :rtype: flytekit.models.common.AuthRole """ fixed_auth = super(SdkLaunchPlan, self).auth_role - if fixed_auth is not None and \ - (fixed_auth.assumable_iam_role is not None or fixed_auth.kubernetes_service_account is not None): - return fixed_auth + if fixed_auth is not None and ( + fixed_auth.assumable_iam_role is not None or fixed_auth.kubernetes_service_account is not None + ): + return fixed_auth assumable_iam_role = _auth_config.ASSUMABLE_IAM_ROLE.get() kubernetes_service_account = _auth_config.KUBERNETES_SERVICE_ACCOUNT.get() if not (assumable_iam_role or kubernetes_service_account): - _logging.warning("Using deprecated `role` from config. " - "Please update your config to use `assumable_iam_role` instead") + _logging.warning( + "Using deprecated `role` from config. " "Please update your config to use `assumable_iam_role` instead" + ) assumable_iam_role = _sdk_config.ROLE.get() - return _common_models.AuthRole(assumable_iam_role=assumable_iam_role, - kubernetes_service_account=kubernetes_service_account) + return _common_models.AuthRole( + assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, + ) @property def interface(self): @@ -149,7 +163,7 @@ def raw_output_data_config(self): :rtype: flytekit.models.common.RawOutputDataConfig """ raw_output_data_config = super(SdkLaunchPlan, self).raw_output_data_config - if raw_output_data_config is not None and raw_output_data_config.output_location_prefix != '': + if raw_output_data_config is not None and raw_output_data_config.output_location_prefix != "": return raw_output_data_config # If it was not set explicitly then let's use the value found in the configuration. 
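(Aside: several SdkLaunchPlan properties in this hunk share one fallback rule: an explicitly set value on the launch plan spec wins, otherwise the configured default applies. A minimal sketch of that rule; the lookup callable and "is set" predicate are placeholders, not flytekit APIs:)

```python
# Sketch of the explicit-value-else-configuration fallback used by properties
# like raw_output_data_config and auth_role above.
def resolve_setting(explicit_value, config_lookup, is_set=bool):
    if explicit_value is not None and is_set(explicit_value):
        return explicit_value  # explicitly set on the launch plan spec
    return config_lookup()  # fall back to the configured default


# An empty output prefix counts as "not set", mirroring the != "" check above:
prefix = resolve_setting("", lambda: "s3://my-bucket", is_set=lambda v: v != "")
assert prefix == "s3://my-bucket"
```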
@@ -180,24 +194,38 @@ def _python_std_input_map_to_literal_map(self, inputs): """ return _type_helpers.pack_python_std_map_to_literal_map( inputs, - { - k: user_input.sdk_type - for k, user_input in _six.iteritems(self.default_inputs.parameters) if k in inputs - } + {k: user_input.sdk_type for k, user_input in _six.iteritems(self.default_inputs.parameters) if k in inputs}, ) - @_deprecated(reason="Use launch_with_literals instead", version='0.9.0') - def execute_with_literals(self, project, domain, literal_inputs, name=None, notification_overrides=None, - label_overrides=None, annotation_overrides=None): + @_deprecated(reason="Use launch_with_literals instead", version="0.9.0") + def execute_with_literals( + self, + project, + domain, + literal_inputs, + name=None, + notification_overrides=None, + label_overrides=None, + annotation_overrides=None, + ): """ Deprecated. """ - return self.launch_with_literals(project, domain, literal_inputs, name, notification_overrides, label_overrides, - annotation_overrides) + return self.launch_with_literals( + project, domain, literal_inputs, name, notification_overrides, label_overrides, annotation_overrides, + ) @_exception_scopes.system_entry_point - def launch_with_literals(self, project, domain, literal_inputs, name=None, notification_overrides=None, - label_overrides=None, annotation_overrides=None): + def launch_with_literals( + self, + project, + domain, + literal_inputs, + name=None, + notification_overrides=None, + label_overrides=None, + annotation_overrides=None, + ): """ Executes the launch plan and returns the execution identifier. This version of execution is meant for when you already have a LiteralMap of inputs. @@ -216,14 +244,18 @@ def launch_with_literals(self, project, domain, literal_inputs, name=None, notif """ # Kubernetes requires names starting with an alphabet for some resources. name = name or "f" + _uuid.uuid4().hex[:19] - execution = _engine_loader.get_engine().get_launch_plan(self).launch( - project, - domain, - name, - literal_inputs, - notification_overrides=notification_overrides, - label_overrides=label_overrides, - annotation_overrides=annotation_overrides, + execution = ( + _engine_loader.get_engine() + .get_launch_plan(self) + .launch( + project, + domain, + name, + literal_inputs, + notification_overrides=notification_overrides, + label_overrides=label_overrides, + annotation_overrides=annotation_overrides, + ) ) return _workflow_execution.SdkWorkflowExecution.promote_from_model(execution) @@ -237,14 +269,11 @@ def __call__(self, *args, **input_map): if len(args) > 0: raise _user_exceptions.FlyteAssertion( "When adding a launchplan as a node in a workflow, all inputs must be specified with kwargs only. 
We " - "detected {} positional args.".format(self, len(args)) + "detected {} positional args.".format(len(args)) ) # Take the default values from the launch plan - default_inputs = { - k: v.sdk_default - for k, v in _six.iteritems(self.default_inputs.parameters) if not v.required - } + default_inputs = {k: v.sdk_default for k, v in _six.iteritems(self.default_inputs.parameters) if not v.required} default_inputs.update(input_map) bindings, upstream_nodes = self.interface.create_bindings_for_inputs(default_inputs) @@ -254,7 +283,7 @@ def __call__(self, *args, **input_map): metadata=_workflow_models.NodeMetadata("", _datetime.timedelta(), _literal_models.RetryStrategy(0)), bindings=sorted(bindings, key=lambda b: b.var), upstream_nodes=upstream_nodes, - sdk_launch_plan=self + sdk_launch_plan=self, ) def __repr__(self): @@ -267,22 +296,20 @@ def __repr__(self): # The difference between this and the SdkLaunchPlan class is that this runnable class is supposed to only be used for # launch plans loaded alongside the current Python interpreter. class SdkRunnableLaunchPlan( - _hash_mixin.HashOnReferenceMixin, - SdkLaunchPlan, - _registerable.RegisterableEntity, + _hash_mixin.HashOnReferenceMixin, SdkLaunchPlan, _registerable.RegisterableEntity, ): def __init__( - self, - sdk_workflow, - default_inputs=None, - fixed_inputs=None, - role=None, - schedule=None, - notifications=None, - labels=None, - annotations=None, - auth_role=None, - raw_output_data_config=None, + self, + sdk_workflow, + default_inputs=None, + fixed_inputs=None, + role=None, + schedule=None, + notifications=None, + labels=None, + annotations=None, + auth_role=None, + raw_output_data_config=None, ): """ :param flytekit.common.workflow.SdkWorkflow sdk_workflow: @@ -318,25 +345,24 @@ def __init__( super(SdkRunnableLaunchPlan, self).__init__( None, _launch_plan_models.LaunchPlanMetadata( - schedule=schedule or _schedule_model.Schedule(''), - notifications=notifications or [] + schedule=schedule or _schedule_model.Schedule(""), notifications=notifications or [], ), _interface_models.ParameterMap(default_inputs), _type_helpers.pack_python_std_map_to_literal_map( fixed_inputs, { k: _type_helpers.get_sdk_type_from_literal_type(var.type) - for k, var in _six.iteritems(sdk_workflow.interface.inputs) if k in fixed_inputs - } + for k, var in _six.iteritems(sdk_workflow.interface.inputs) + if k in fixed_inputs + }, ), labels or _common_models.Labels({}), annotations or _common_models.Annotations({}), auth_role, - raw_output_data_config or _common_models.RawOutputDataConfig(''), + raw_output_data_config or _common_models.RawOutputDataConfig(""), ) self._interface = _interface.TypedInterface( - {k: v.var for k, v in _six.iteritems(default_inputs)}, - sdk_workflow.interface.outputs + {k: v.var for k, v in _six.iteritems(default_inputs)}, sdk_workflow.interface.outputs, ) self._upstream_entities = {sdk_workflow} self._sdk_workflow = sdk_workflow @@ -351,11 +377,7 @@ def register(self, project, domain, name, version): """ self.validate() id_to_register = _identifier.Identifier( - _identifier_model.ResourceType.LAUNCH_PLAN, - project, - domain, - name, - version + _identifier_model.ResourceType.LAUNCH_PLAN, project, domain, name, version ) _engine_loader.get_engine().get_launch_plan(self).register(id_to_register) self._id = id_to_register diff --git a/flytekit/common/mixins/artifact.py b/flytekit/common/mixins/artifact.py index aeb9075b91..89bde8f34a 100644 --- a/flytekit/common/mixins/artifact.py +++ b/flytekit/common/mixins/artifact.py @@ -1,14 +1,16 
@@ from __future__ import absolute_import + import abc as _abc import datetime as _datetime -import six as _six import time as _time -from flytekit.models import common as _common_models + +import six as _six + from flytekit.common.exceptions import user as _user_exceptions +from flytekit.models import common as _common_models class ExecutionArtifact(_six.with_metaclass(_common_models.FlyteABCMeta, object)): - @_abc.abstractproperty def inputs(self): """ diff --git a/flytekit/common/mixins/launchable.py b/flytekit/common/mixins/launchable.py index 842f956a5d..97f8d48455 100644 --- a/flytekit/common/mixins/launchable.py +++ b/flytekit/common/mixins/launchable.py @@ -1,13 +1,22 @@ from __future__ import absolute_import + import abc as _abc -import six as _six +import six as _six from deprecated import deprecated as _deprecated class LaunchableEntity(_six.with_metaclass(_abc.ABCMeta, object)): - def launch(self, project, domain, inputs=None, name=None, notification_overrides=None, label_overrides=None, - annotation_overrides=None): + def launch( + self, + project, + domain, + inputs=None, + name=None, + notification_overrides=None, + label_overrides=None, + annotation_overrides=None, + ): """ Creates a remote execution from the entity and returns the execution identifier. This version of launch is meant for when inputs are specified as Python native types/structures. @@ -36,9 +45,17 @@ def launch(self, project, domain, inputs=None, name=None, notification_overrides annotation_overrides=annotation_overrides, ) - @_deprecated(reason="Use launch instead", version='0.9.0') - def execute(self, project, domain, inputs=None, name=None, notification_overrides=None, label_overrides=None, - annotation_overrides=None): + @_deprecated(reason="Use launch instead", version="0.9.0") + def execute( + self, + project, + domain, + inputs=None, + name=None, + notification_overrides=None, + label_overrides=None, + annotation_overrides=None, + ): """ Deprecated. """ @@ -57,8 +74,16 @@ def _python_std_input_map_to_literal_map(self, inputs): pass @_abc.abstractmethod - def launch_with_literals(self, project, domain, literal_inputs, name=None, notification_overrides=None, - label_overrides=None, annotation_overrides=None): + def launch_with_literals( + self, + project, + domain, + literal_inputs, + name=None, + notification_overrides=None, + label_overrides=None, + annotation_overrides=None, + ): """ Executes the entity and returns the execution identifier. This version of execution is meant for when you already have a LiteralMap of inputs. @@ -77,11 +102,20 @@ def launch_with_literals(self, project, domain, literal_inputs, name=None, notif """ pass - @_deprecated(reason="Use launch_with_literals instead", version='0.9.0') - def execute_with_literals(self, project, domain, literal_inputs, name=None, notification_overrides=None, - label_overrides=None, annotation_overrides=None): + @_deprecated(reason="Use launch_with_literals instead", version="0.9.0") + def execute_with_literals( + self, + project, + domain, + literal_inputs, + name=None, + notification_overrides=None, + label_overrides=None, + annotation_overrides=None, + ): """ Deprecated. 
""" - return self.launch_with_literals(project, domain, literal_inputs, name, notification_overrides, label_overrides, - annotation_overrides) + return self.launch_with_literals( + project, domain, literal_inputs, name, notification_overrides, label_overrides, annotation_overrides, + ) diff --git a/flytekit/common/mixins/registerable.py b/flytekit/common/mixins/registerable.py index ba0a4062fe..9eed525ba1 100644 --- a/flytekit/common/mixins/registerable.py +++ b/flytekit/common/mixins/registerable.py @@ -1,13 +1,15 @@ from __future__ import absolute_import + import abc as _abc -import inspect as _inspect -import six as _six import importlib as _importlib +import inspect as _inspect import logging as _logging +import six as _six + from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common.exceptions import system as _system_exceptions from flytekit.common import utils as _utils +from flytekit.common.exceptions import system as _system_exceptions class _InstanceTracker(_sdk_bases.ExtendedSdkType): @@ -21,12 +23,13 @@ class _InstanceTracker(_sdk_bases.ExtendedSdkType): like to only register a task once and do so with the name where it is defined. This metaclass allows us to do this by inspecting the call stack when __call__ is called on the metaclass (thus instantiating an object). """ + @staticmethod def _find_instance_module(): frame = _inspect.currentframe() while frame: - if frame.f_code.co_name == '': - return frame.f_globals['__name__'] + if frame.f_code.co_name == "": + return frame.f_globals["__name__"] frame = frame.f_back return None @@ -37,7 +40,6 @@ def __call__(cls, *args, **kwargs): class RegisterableEntity(_six.with_metaclass(_InstanceTracker, object)): - def __init__(self, *args, **kwargs): self._platform_valid_name = None super(RegisterableEntity, self).__init__(*args, **kwargs) diff --git a/flytekit/common/nodes.py b/flytekit/common/nodes.py index 37b6b23c91..419455bc61 100644 --- a/flytekit/common/nodes.py +++ b/flytekit/common/nodes.py @@ -6,17 +6,23 @@ import six as _six from sortedcontainers import SortedDict as _SortedDict +from flytekit.common import component_nodes as _component_nodes from flytekit.common import constants as _constants -from flytekit.common import sdk_bases as _sdk_bases, promise as _promise, component_nodes as _component_nodes -from flytekit.common.exceptions import scopes as _exception_scopes, user as _user_exceptions +from flytekit.common import promise as _promise +from flytekit.common import sdk_bases as _sdk_bases +from flytekit.common.exceptions import scopes as _exception_scopes from flytekit.common.exceptions import system as _system_exceptions -from flytekit.common.mixins import hash as _hash_mixin, artifact as _artifact_mixin +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.mixins import artifact as _artifact_mixin +from flytekit.common.mixins import hash as _hash_mixin from flytekit.common.tasks import executions as _task_executions from flytekit.common.types import helpers as _type_helpers from flytekit.common.utils import _dnsify from flytekit.engines import loader as _engine_loader -from flytekit.models import common as _common_models, node_execution as _node_execution_models -from flytekit.models.core import workflow as _workflow_model, execution as _execution_models +from flytekit.models import common as _common_models +from flytekit.models import node_execution as _node_execution_models +from flytekit.models.core import execution as _execution_models +from flytekit.models.core 
import workflow as _workflow_model class ParameterMapper(_six.with_metaclass(_common_models.FlyteABCMeta, _SortedDict)): @@ -62,7 +68,7 @@ def __init__(self, type_map, node): self._initialized = True def __getattr__(self, key): - if key == 'iteritems' and hasattr(super(ParameterMapper, self), 'items'): + if key == "iteritems" and hasattr(super(ParameterMapper, self), "items"): return super(ParameterMapper, self).items if hasattr(super(ParameterMapper, self), key): return getattr(super(ParameterMapper, self), key) @@ -71,7 +77,7 @@ def __getattr__(self, key): return self[key] def __setattr__(self, key, value): - if '_initialized' in self.__dict__: + if "_initialized" in self.__dict__: raise _user_exceptions.FlyteAssertion("Parameters are immutable.") else: super(ParameterMapper, self).__setattr__(key, value) @@ -100,18 +106,17 @@ def _return_mapping_object(self, sdk_node, sdk_type, name): return _promise.NodeOutput(sdk_node, sdk_type, name) -class SdkNode(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _hash_mixin.HashOnReferenceMixin, _workflow_model.Node)): - +class SdkNode(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _hash_mixin.HashOnReferenceMixin, _workflow_model.Node,)): def __init__( - self, - id, - upstream_nodes, - bindings, - metadata, - sdk_task=None, - sdk_workflow=None, - sdk_launch_plan=None, - sdk_branch=None + self, + id, + upstream_nodes, + bindings, + metadata, + sdk_task=None, + sdk_workflow=None, + sdk_launch_plan=None, + sdk_branch=None, ): """ :param Text id: A workflow-level unique identifier that identifies this node in the workflow. "inputs" and @@ -130,15 +135,12 @@ def __init__( :param TODO sdk_branch: TODO """ non_none_entities = [ - entity - for entity in [sdk_workflow, sdk_branch, sdk_launch_plan, sdk_task] if entity is not None + entity for entity in [sdk_workflow, sdk_branch, sdk_launch_plan, sdk_task] if entity is not None ] if len(non_none_entities) != 1: raise _user_exceptions.FlyteAssertion( "An SDK node must have one underlying entity specified at once. 
Received the following " - "entities: {}".format( - non_none_entities - ) + "entities: {}".format(non_none_entities) ) workflow_node = None @@ -155,7 +157,7 @@ def __init__( output_aliases=[], # TODO: Are aliases a thing in SDK nodes task_node=_component_nodes.SdkTaskNode(sdk_task) if sdk_task else None, workflow_node=workflow_node, - branch_node=sdk_branch.target if sdk_branch else None + branch_node=sdk_branch.target if sdk_branch else None, ) self._upstream = upstream_nodes self._executable_sdk_object = sdk_task or sdk_workflow or sdk_branch or sdk_launch_plan @@ -187,7 +189,8 @@ def promote_from_model(cls, model, sub_workflows, tasks): sdk_task_node = _component_nodes.SdkTaskNode.promote_from_model(model.task_node, tasks) elif model.workflow_node is not None: sdk_workflow_node = _component_nodes.SdkWorkflowNode.promote_from_model( - model.workflow_node, sub_workflows, tasks) + model.workflow_node, sub_workflows, tasks + ) else: raise _system_exceptions.FlyteSystemException("Bad Node model, neither task nor workflow detected") @@ -224,7 +227,8 @@ def promote_from_model(cls, model, sub_workflows, tasks): ) else: raise _system_exceptions.FlyteSystemException( - "Bad SdkWorkflowNode model, both lp and workflow are None") + "Bad SdkWorkflowNode model, both lp and workflow are None" + ) else: raise _system_exceptions.FlyteSystemException("Bad SdkNode model, both task and workflow nodes are empty") @@ -297,9 +301,7 @@ def __repr__(self): class SdkNodeExecution( _six.with_metaclass( - _sdk_bases.ExtendedSdkType, - _node_execution_models.NodeExecution, - _artifact_mixin.ExecutionArtifact + _sdk_bases.ExtendedSdkType, _node_execution_models.NodeExecution, _artifact_mixin.ExecutionArtifact, ) ): def __init__(self, *args, **kwargs): @@ -353,8 +355,9 @@ def outputs(self): :rtype: dict[Text, T] """ if not self.is_complete: - raise _user_exceptions.FlyteAssertion("Please what until the node execution has completed before " - "requesting the outputs.") + raise _user_exceptions.FlyteAssertion( + "Please what until the node execution has completed before " "requesting the outputs." + ) if self.error: raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") @@ -372,8 +375,9 @@ def error(self): :rtype: flytekit.models.core.execution.ExecutionError or None """ if not self.is_complete: - raise _user_exceptions.FlyteAssertion("Please what until the node execution has completed before " - "requesting error information.") + raise _user_exceptions.FlyteAssertion( + "Please what until the node execution has completed before " "requesting error information." 
+ ) return self.closure.error @property @@ -396,11 +400,7 @@ def promote_from_model(cls, base_model): :param _node_execution_models.NodeExecution base_model: :rtype: SdkNodeExecution """ - return cls( - closure=base_model.closure, - id=base_model.id, - input_uri=base_model.input_uri - ) + return cls(closure=base_model.closure, id=base_model.id, input_uri=base_model.input_uri) def sync(self): """ diff --git a/flytekit/common/notifications.py b/flytekit/common/notifications.py index f549e8d81b..a9b5fbfc44 100644 --- a/flytekit/common/notifications.py +++ b/flytekit/common/notifications.py @@ -1,9 +1,11 @@ from __future__ import absolute_import -from flytekit.models import common as _common_model -from flytekit.models.core import execution as _execution_model + +import six as _six + from flytekit.common import sdk_bases as _sdk_bases from flytekit.common.exceptions import user as _user_exceptions -import six as _six +from flytekit.models import common as _common_model +from flytekit.models.core import execution as _execution_model class Notification(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _common_model.Notification)): @@ -12,7 +14,7 @@ class Notification(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _common_model _execution_model.WorkflowExecutionPhase.ABORTED, _execution_model.WorkflowExecutionPhase.FAILED, _execution_model.WorkflowExecutionPhase.SUCCEEDED, - _execution_model.WorkflowExecutionPhase.TIMED_OUT + _execution_model.WorkflowExecutionPhase.TIMED_OUT, } def __init__(self, phases, email=None, pager_duty=None, slack=None): @@ -34,7 +36,7 @@ def _validate_phases(self, phases): raise _user_exceptions.FlyteValueException( phase, self.VALID_PHASES, - additional_message="Notifications can only be specified on terminal states." + additional_message="Notifications can only be specified on terminal states.", ) @classmethod @@ -64,10 +66,7 @@ def promote_from_model(cls, base_model): :param flytekit.models.common.Notification base_model: :rtype: Notification """ - return cls( - base_model.phases, - base_model.pager_duty.recipients_email - ) + return cls(base_model.phases, base_model.pager_duty.recipients_email) class Email(Notification): @@ -85,10 +84,7 @@ def promote_from_model(cls, base_model): :param flytekit.models.common.Notification base_model: :rtype: Notification """ - return cls( - base_model.phases, - base_model.email.recipients_email - ) + return cls(base_model.phases, base_model.email.recipients_email) class Slack(Notification): @@ -106,7 +102,4 @@ def promote_from_model(cls, base_model): :param flytekit.models.common.Notification base_model: :rtype: Notification """ - return cls( - base_model.phases, - base_model.slack.recipients_email - ) + return cls(base_model.phases, base_model.slack.recipients_email) diff --git a/flytekit/common/promise.py b/flytekit/common/promise.py index c05db95ec2..9c31b0534c 100644 --- a/flytekit/common/promise.py +++ b/flytekit/common/promise.py @@ -2,14 +2,15 @@ import six as _six -from flytekit.common import constants as _constants, sdk_bases as _sdk_bases +from flytekit.common import constants as _constants +from flytekit.common import sdk_bases as _sdk_bases from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import helpers as _type_helpers -from flytekit.models import interface as _interface_models, types as _type_models +from flytekit.models import interface as _interface_models +from flytekit.models import types as _type_models class Input(_six.with_metaclass(_sdk_bases.ExtendedSdkType, 
_interface_models.Parameter)): - def __init__(self, name, sdk_type, help=None, **kwargs): """ :param Text name: @@ -20,22 +21,22 @@ def __init__(self, name, sdk_type, help=None, **kwargs): :param T default: If this is not a required input, the value will default to this value. """ param_default = None - if 'required' not in kwargs and 'default' not in kwargs: + if "required" not in kwargs and "default" not in kwargs: # Neither required or default is set so assume required required = True default = None - elif kwargs.get('required', False) and 'default' in kwargs: + elif kwargs.get("required", False) and "default" in kwargs: # Required cannot be set to True and have a default specified raise _user_exceptions.FlyteAssertion("Default cannot be set when required is True") - elif 'default' in kwargs: + elif "default" in kwargs: # If default is specified, then required must be false and the value is whatever is specified required = None - default = kwargs['default'] + default = kwargs["default"] param_default = sdk_type.from_python_std(default) else: # If no default is set, but required is set, then the behavior is determined by required == True or False default = None - required = kwargs['required'] + required = kwargs["required"] if not required: # If required == False, we assume default to be None param_default = sdk_type.from_python_std(default) @@ -48,9 +49,9 @@ def __init__(self, name, sdk_type, help=None, **kwargs): self._promise = _type_models.OutputReference(_constants.GLOBAL_INPUT_NODE_ID, name) self._name = name super(Input, self).__init__( - _interface_models.Variable(type=sdk_type.to_flyte_literal_type(), description=help or ''), + _interface_models.Variable(type=sdk_type.to_flyte_literal_type(), description=help or ""), required=required, - default=param_default + default=param_default, ) def rename_and_return_reference(self, new_name): @@ -112,24 +113,12 @@ def promote_from_model(cls, model): if model.default is not None: default_value = sdk_type.from_flyte_idl(model.default.to_flyte_idl()).to_python_std() - return cls( - "", - sdk_type, - help=model.var.description, - required=False, - default=default_value - ) + return cls("", sdk_type, help=model.var.description, required=False, default=default_value,) else: - return cls( - "", - sdk_type, - help=model.var.description, - required=True - ) + return cls("", sdk_type, help=model.var.description, required=True) class NodeOutput(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _type_models.OutputReference)): - def __init__(self, sdk_node, sdk_type, var): """ :param flytekit.common.nodes.SdkNode sdk_node: @@ -138,10 +127,7 @@ def __init__(self, sdk_node, sdk_type, var): """ self._node = sdk_node self._type = sdk_type - super(NodeOutput, self).__init__( - self._node.id, - var - ) + super(NodeOutput, self).__init__(self._node.id, var) @property def node_id(self): @@ -157,8 +143,10 @@ def promote_from_model(cls, model): :param flytekit.models.types.OutputReference model: :rtype: NodeOutput """ - raise _user_exceptions.FlyteAssertion("A NodeOutput cannot be promoted from a protobuf because it must be " - "contextualized by an existing SdkNode.") + raise _user_exceptions.FlyteAssertion( + "A NodeOutput cannot be promoted from a protobuf because it must be " + "contextualized by an existing SdkNode." 
+ ) @property def sdk_node(self): diff --git a/flytekit/common/schedules.py b/flytekit/common/schedules.py index 03618570da..76afcab18b 100644 --- a/flytekit/common/schedules.py +++ b/flytekit/common/schedules.py @@ -1,11 +1,14 @@ from __future__ import absolute_import, division -from flytekit.models import schedule as _schedule_models -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common.exceptions import user as _user_exceptions -import croniter as _croniter + import datetime as _datetime + +import croniter as _croniter import six as _six +from flytekit.common import sdk_bases as _sdk_bases +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.models import schedule as _schedule_models + class _ExtendedSchedule(_schedule_models.Schedule): @classmethod @@ -18,7 +21,6 @@ def from_flyte_idl(cls, proto): class CronSchedule(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _ExtendedSchedule)): - def __init__(self, cron_expression, kickoff_time_input_arg=None): """ :param Text cron_expression: @@ -41,12 +43,10 @@ def _validate_expression(cron_expression): if len(tokens) != 6: raise _user_exceptions.FlyteAssertion( "Cron expression is invalid. A cron expression must have 6 fields. Cron expressions are in the " - "format of: `minute hour day-of-month month day-of-week year`. Received: `{}`".format( - cron_expression - ) + "format of: `minute hour day-of-month month day-of-week year`. Received: `{}`".format(cron_expression) ) - if tokens[2] != '?' and tokens[4] != '?': + if tokens[2] != "?" and tokens[4] != "?": raise _user_exceptions.FlyteAssertion( "Scheduled string is invalid. A cron expression must have a '?' for either day-of-month or " "day-of-week. Please specify '?' for one of those fields. Cron expressions are in the format of: " @@ -58,13 +58,11 @@ def _validate_expression(cron_expression): try: # Cut to 5 fields and just assume year field is good because croniter treats the 6th field as seconds. # TODO: Parse this field ourselves and check - _croniter.croniter(" ".join(cron_expression.replace('?', '*').split()[:5])) - except: + _croniter.croniter(" ".join(cron_expression.replace("?", "*").split()[:5])) + except Exception: raise _user_exceptions.FlyteAssertion( "Scheduled string is invalid. The cron expression was found to be invalid." 
- " Provided cron expr: {}".format( - cron_expression - ) + " Provided cron expr: {}".format(cron_expression) ) @classmethod @@ -73,14 +71,10 @@ def promote_from_model(cls, base_model): :param flytekit.models.schedule.Schedule base_model: :rtype: CronSchedule """ - return cls( - base_model.cron_expression, - kickoff_time_input_arg=base_model.kickoff_time_input_arg - ) + return cls(base_model.cron_expression, kickoff_time_input_arg=base_model.kickoff_time_input_arg,) class FixedRate(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _ExtendedSchedule)): - def __init__(self, duration, kickoff_time_input_arg=None): """ :param datetime.timedelta duration: @@ -106,18 +100,15 @@ def _translate_duration(duration): ) elif int(duration.total_seconds()) % _SECONDS_TO_DAYS == 0: return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_DAYS), - _schedule_models.Schedule.FixedRateUnit.DAY + int(duration.total_seconds() / _SECONDS_TO_DAYS), _schedule_models.Schedule.FixedRateUnit.DAY, ) elif int(duration.total_seconds()) % _SECONDS_TO_HOURS == 0: return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_HOURS), - _schedule_models.Schedule.FixedRateUnit.HOUR + int(duration.total_seconds() / _SECONDS_TO_HOURS), _schedule_models.Schedule.FixedRateUnit.HOUR, ) else: return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_MINUTES), - _schedule_models.Schedule.FixedRateUnit.MINUTE + int(duration.total_seconds() / _SECONDS_TO_MINUTES), _schedule_models.Schedule.FixedRateUnit.MINUTE, ) @classmethod @@ -133,7 +124,4 @@ def promote_from_model(cls, base_model): else: duration = _datetime.timedelta(minutes=base_model.rate.value) - return cls( - duration, - kickoff_time_input_arg=base_model.kickoff_time_input_arg - ) + return cls(duration, kickoff_time_input_arg=base_model.kickoff_time_input_arg) diff --git a/flytekit/common/sdk_bases.py b/flytekit/common/sdk_bases.py index 11e96f13b1..743ee072ff 100644 --- a/flytekit/common/sdk_bases.py +++ b/flytekit/common/sdk_bases.py @@ -1,8 +1,11 @@ from __future__ import absolute_import -from flytekit.models import common as _common + import abc as _abc + import six as _six +from flytekit.models import common as _common + class ExtendedSdkType(_six.with_metaclass(_common.FlyteABCMeta, _common.FlyteType)): """ diff --git a/flytekit/common/tasks/executions.py b/flytekit/common/tasks/executions.py index f77728a3c4..9b380ada89 100644 --- a/flytekit/common/tasks/executions.py +++ b/flytekit/common/tasks/executions.py @@ -1,4 +1,7 @@ from __future__ import absolute_import + +import six as _six + from flytekit.common import sdk_bases as _sdk_bases from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.mixins import artifact as _artifact_mixin @@ -6,14 +9,11 @@ from flytekit.engines import loader as _engine_loader from flytekit.models.admin import task_execution as _task_execution_model from flytekit.models.core import execution as _execution_models -import six as _six class SdkTaskExecution( _six.with_metaclass( - _sdk_bases.ExtendedSdkType, - _task_execution_model.TaskExecution, - _artifact_mixin.ExecutionArtifact, + _sdk_bases.ExtendedSdkType, _task_execution_model.TaskExecution, _artifact_mixin.ExecutionArtifact, ) ): def __init__(self, *args, **kwargs): @@ -55,8 +55,9 @@ def outputs(self): :rtype: dict[Text, T] """ if not self.is_complete: - raise _user_exceptions.FlyteAssertion("Please what until the task execution has completed before " - "requesting the 
outputs.") + raise _user_exceptions.FlyteAssertion( + "Please what until the task execution has completed before " "requesting the outputs." + ) if self.error: raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") @@ -74,8 +75,9 @@ def error(self): :rtype: flytekit.models.core.execution.ExecutionError or None """ if not self.is_complete: - raise _user_exceptions.FlyteAssertion("Please what until the task execution has completed before " - "requesting error information.") + raise _user_exceptions.FlyteAssertion( + "Please what until the task execution has completed before " "requesting error information." + ) return self.closure.error def get_child_executions(self, filters=None): @@ -84,13 +86,11 @@ def get_child_executions(self, filters=None): :rtype: dict[Text, flytekit.common.nodes.SdkNodeExecution] """ from flytekit.common import nodes as _nodes + if not self.is_parent: raise _user_exceptions.FlyteAssertion("Only task executions marked with 'is_parent' have child executions.") models = _engine_loader.get_engine().get_task_execution(self).get_child_executions(filters=filters) - return { - k: _nodes.SdkNodeExecution.promote_from_model(v) - for k, v in _six.iteritems(models) - } + return {k: _nodes.SdkNodeExecution.promote_from_model(v) for k, v in _six.iteritems(models)} @classmethod def promote_from_model(cls, base_model): diff --git a/flytekit/common/tasks/generic_spark_task.py b/flytekit/common/tasks/generic_spark_task.py index 2a53816cb3..03d00b00e6 100644 --- a/flytekit/common/tasks/generic_spark_task.py +++ b/flytekit/common/tasks/generic_spark_task.py @@ -1,54 +1,53 @@ from __future__ import absolute_import -try: - from inspect import getfullargspec as _getargspec -except ImportError: - from inspect import getargspec as _getargspec - -from flytekit import __version__ import sys as _sys -import six as _six -from flytekit.common.tasks import task as _base_tasks -from flytekit.common.types import helpers as _helpers, primitives as _primitives -from flytekit.models import literals as _literal_models, task as _task_models +import six as _six from google.protobuf.json_format import MessageToDict as _MessageToDict + +from flytekit import __version__ from flytekit.common import interface as _interface -from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.exceptions import scopes as _exception_scopes - +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.tasks import task as _base_tasks +from flytekit.common.types import helpers as _helpers +from flytekit.common.types import primitives as _primitives from flytekit.configuration import internal as _internal_config +from flytekit.models import literals as _literal_models +from flytekit.models import task as _task_models -input_types_supported = { _primitives.Integer, - _primitives.Boolean, - _primitives.Float, - _primitives.String, - _primitives.Datetime, - _primitives.Timedelta, - } +input_types_supported = { + _primitives.Integer, + _primitives.Boolean, + _primitives.Float, + _primitives.String, + _primitives.Datetime, + _primitives.Timedelta, +} -class SdkGenericSparkTask( _base_tasks.SdkTask): +class SdkGenericSparkTask(_base_tasks.SdkTask): """ This class includes the additional logic for building a task that executes as a Spark Job. 
""" + def __init__( - self, - task_type, - discovery_version, - retries, - interruptible, - task_inputs, - deprecated, - discoverable, - timeout, - spark_type, - main_class, - main_application_file, - spark_conf, - hadoop_conf, - environment, + self, + task_type, + discovery_version, + retries, + interruptible, + task_inputs, + deprecated, + discoverable, + timeout, + spark_type, + main_class, + main_application_file, + spark_conf, + hadoop_conf, + environment, ): """ :param Text task_type: string describing the task type @@ -69,7 +68,7 @@ def __init__( spark_job = _task_models.SparkJob( spark_conf=spark_conf, hadoop_conf=hadoop_conf, - spark_type = spark_type, + spark_type=spark_type, application_file=main_application_file, main_class=main_class, executor_path=_sys.executable, @@ -79,16 +78,12 @@ def __init__( task_type, _task_models.TaskMetadata( discoverable, - _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, - __version__, - 'spark' - ), + _task_models.RuntimeMetadata(_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "spark",), timeout, _literal_models.RetryStrategy(retries), interruptible, discovery_version, - deprecated + deprecated, ), _interface.TypedInterface({}, {}), _MessageToDict(spark_job), @@ -99,9 +94,7 @@ def __init__( task_inputs(self) # Container after the Inputs have been updated. - self._container = self._get_container_definition( - environment=environment - ) + self._container = self._get_container_definition(environment=environment) def _validate_inputs(self, inputs): """ @@ -109,10 +102,12 @@ def _validate_inputs(self, inputs): :raises: flytekit.common.exceptions.user.FlyteValidationException """ for k, v in _six.iteritems(inputs): - sdk_type =_helpers.get_sdk_type_from_literal_type(v.type) + sdk_type = _helpers.get_sdk_type_from_literal_type(v.type) if sdk_type not in input_types_supported: raise _user_exceptions.FlyteValidationException( - "Input Type '{}' not supported. Only Primitives are supported for Scala/Java Spark.".format(sdk_type) + "Input Type '{}' not supported. 
Only Primitives are supported for Scala/Java Spark.".format( + sdk_type + ) ) super(SdkGenericSparkTask, self)._validate_inputs(inputs) @@ -128,8 +123,7 @@ def add_inputs(self, inputs): self.interface.inputs.update(inputs) def _get_container_definition( - self, - environment=None, + self, environment=None, ): """ :rtype: Container @@ -146,5 +140,5 @@ def _get_container_definition( args=args, resources=_task_models.Resources([], []), env=environment, - config={} + config={}, ) diff --git a/flytekit/common/tasks/hive_task.py b/flytekit/common/tasks/hive_task.py index 528c260796..42436dc61b 100644 --- a/flytekit/common/tasks/hive_task.py +++ b/flytekit/common/tasks/hive_task.py @@ -5,19 +5,20 @@ import six as _six from google.protobuf.json_format import MessageToDict as _MessageToDict -from flytekit.common import constants as _constants, nodes as _nodes, interface as _interface +from flytekit.common import constants as _constants +from flytekit.common import interface as _interface +from flytekit.common import nodes as _nodes from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.exceptions.user import FlyteTypeException as _FlyteTypeException, \ - FlyteValueException as _FlyteValueException +from flytekit.common.exceptions.user import FlyteTypeException as _FlyteTypeException +from flytekit.common.exceptions.user import FlyteValueException as _FlyteValueException from flytekit.common.tasks import output as _task_output -from flytekit.common.tasks import sdk_runnable as _sdk_runnable, task as _base_task +from flytekit.common.tasks import sdk_runnable as _sdk_runnable +from flytekit.common.tasks import task as _base_task from flytekit.common.types import helpers as _type_helpers -from flytekit.models import ( - qubole as _qubole, - interface as _interface_model, - literals as _literal_models, - dynamic_job as _dynamic_job -) +from flytekit.models import dynamic_job as _dynamic_job +from flytekit.models import interface as _interface_model +from flytekit.models import literals as _literal_models +from flytekit.models import qubole as _qubole from flytekit.models.core import workflow as _workflow_model ALLOWED_TAGS_COUNT = int(6) @@ -30,26 +31,26 @@ class SdkHiveTask(_sdk_runnable.SdkRunnableTask): """ def __init__( - self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - cluster_label, - tags, - environment + self, + task_function, + task_type, + discovery_version, + retries, + interruptible, + deprecated, + storage_request, + cpu_request, + gpu_request, + memory_request, + storage_limit, + cpu_limit, + gpu_limit, + memory_limit, + discoverable, + timeout, + cluster_label, + tags, + environment, ): """ :param task_function: Function container user code. This will be executed via the SDK's engine. 
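The `_validate_task_parameters` hunk further down in this file's diff enforces the constraints `SdkHiveTask` places on `cluster_label` and `tags`: the label must be text, and `tags` must be a list of at most `ALLOWED_TAGS_COUNT` strings, each non-empty and under the per-tag length cap. A minimal standalone sketch of those rules, with plain `ValueError`s standing in for flytekit's exception types and an assumed `MAX_TAG_LENGTH` value (the real constant is defined near the top of `hive_task.py`); the length check here follows what the hunk's error message describes:

```python
ALLOWED_TAGS_COUNT = 6
MAX_TAG_LENGTH = 128  # assumed value for illustration only


def validate_hive_task_parameters(cluster_label, tags):
    # cluster_label must be text.
    if not isinstance(cluster_label, str):
        raise ValueError("cluster_label for a hive task must be in text format")
    if tags is None:
        return
    # tags must be a list of strings...
    if not (isinstance(tags, list) and all(isinstance(tag, str) for tag in tags)):
        raise ValueError("tags for a hive task must be in 'list of text' format")
    # ...capped in count...
    if len(tags) > ALLOWED_TAGS_COUNT:
        raise ValueError("number of tags must be less than {}".format(ALLOWED_TAGS_COUNT))
    # ...and each tag non-empty and shorter than the length cap.
    if not all(0 < len(tag) < MAX_TAG_LENGTH for tag in tags):
        raise ValueError("length of a tag must be less than {} chars".format(MAX_TAG_LENGTH))


validate_hive_task_parameters("default", ["etl", "hourly"])  # passes silently
```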
@@ -72,9 +73,26 @@ def __init__( :param dict[Text, Text] environment: """ self._task_function = task_function - super(SdkHiveTask, self).__init__(task_function, task_type, discovery_version, retries, interruptible, deprecated, - storage_request, cpu_request, gpu_request, memory_request, storage_limit, - cpu_limit, gpu_limit, memory_limit, discoverable, timeout, environment, {}) + super(SdkHiveTask, self).__init__( + task_function, + task_type, + discovery_version, + retries, + interruptible, + deprecated, + storage_request, + cpu_request, + gpu_request, + memory_request, + storage_limit, + cpu_limit, + gpu_limit, + memory_limit, + discoverable, + timeout, + environment, + {}, + ) self._validate_task_parameters(cluster_label, tags) self._cluster_label = cluster_label self._tags = tags @@ -94,8 +112,9 @@ def _generate_plugin_objects(self, context, inputs_dict): plugin_objects = [] for q in queries_from_task: - hive_query = _qubole.HiveQuery(query=q, timeout_sec=self.metadata.timeout.seconds, - retry_count=self.metadata.retries.retries) + hive_query = _qubole.HiveQuery( + query=q, timeout_sec=self.metadata.timeout.seconds, retry_count=self.metadata.retries.retries, + ) # TODO: Remove this after all users of older SDK versions that did the single node, multi-query pattern are # deprecated. This is only here for backwards compatibility - in addition to writing the query to the @@ -103,8 +122,9 @@ def _generate_plugin_objects(self, context, inputs_dict): # older plugin will continue to work. query_collection = _qubole.HiveQueryCollection([hive_query]) - plugin_objects.append(_qubole.QuboleHiveJob(hive_query, self._cluster_label, self._tags, - query_collection=query_collection)) + plugin_objects.append( + _qubole.QuboleHiveJob(hive_query, self._cluster_label, self._tags, query_collection=query_collection,) + ) return plugin_objects @@ -115,7 +135,7 @@ def _validate_task_parameters(cluster_label, tags): type(cluster_label), {str, _six.text_type}, additional_msg="cluster_label for a hive task must be in text format", - received_value=cluster_label + received_value=cluster_label, ) if tags is not None: if not (isinstance(tags, list) and all(isinstance(tag, (str, _six.text_type)) for tag in tags)): @@ -123,12 +143,16 @@ def _validate_task_parameters(cluster_label, tags): type(tags), [], additional_msg="tags for a hive task must be in 'list of text' format", - received_value=tags + received_value=tags, ) if len(tags) > ALLOWED_TAGS_COUNT: - raise _FlyteValueException(len(tags), "number of tags must be less than {}".format(ALLOWED_TAGS_COUNT)) + raise _FlyteValueException( + len(tags), "number of tags must be less than {}".format(ALLOWED_TAGS_COUNT), + ) if not all(len(tag) for tag in tags): - raise _FlyteValueException(tags, "length of a tag must be less than {} chars".format(MAX_TAG_LENGTH)) + raise _FlyteValueException( + tags, "length of a tag must be less than {} chars".format(MAX_TAG_LENGTH), + ) @staticmethod def _validate_queries(queries_from_task): @@ -138,7 +162,7 @@ def _validate_queries(queries_from_task): type(query_from_task), {str, _six.text_type}, additional_msg="All queries returned from a Hive task must be in text format.", - received_value=query_from_task + received_value=query_from_task, ) def _produce_dynamic_job_spec(self, context, inputs): @@ -148,9 +172,10 @@ def _produce_dynamic_job_spec(self, context, inputs): :param flytekit.models.literals.LiteralMap literal_map inputs: :rtype: flytekit.models.dynamic_job.DynamicJobSpec """ - inputs_dict = 
_type_helpers.unpack_literal_map_to_sdk_python_std(inputs, { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs) - }) + inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std( + inputs, + {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs)}, + ) outputs_dict = { name: _task_output.OutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type)) for name, variable in _six.iteritems(self.interface.outputs) @@ -165,27 +190,23 @@ def _produce_dynamic_job_spec(self, context, inputs): generated_queries = self._generate_plugin_objects(context, inputs_dict) # Create output bindings always - this has to happen after user code has run - output_bindings = [_literal_models.Binding(var=name, binding=_interface.BindingData.from_python_std( - b.sdk_type.to_flyte_literal_type(), b.value)) - for name, b in _six.iteritems(outputs_dict)] + output_bindings = [ + _literal_models.Binding( + var=name, binding=_interface.BindingData.from_python_std(b.sdk_type.to_flyte_literal_type(), b.value), + ) + for name, b in _six.iteritems(outputs_dict) + ] i = 0 for quboleHiveJob in generated_queries: - hive_job_node = _create_hive_job_node( - "HiveQuery_{}".format(i), - quboleHiveJob.to_flyte_idl(), - self.metadata - ) + hive_job_node = _create_hive_job_node("HiveQuery_{}".format(i), quboleHiveJob.to_flyte_idl(), self.metadata) nodes.append(hive_job_node) tasks.append(hive_job_node.executable_sdk_object) i += 1 dynamic_job_spec = _dynamic_job.DynamicJobSpec( - min_successes=len(nodes), - tasks=tasks, - nodes=nodes, - outputs=output_bindings, - subworkflows=[]) + min_successes=len(nodes), tasks=tasks, nodes=nodes, outputs=output_bindings, subworkflows=[], + ) return dynamic_job_spec @@ -211,12 +232,11 @@ def execute(self, context, inputs): if len(spec.nodes) == 0: return { _constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap( - literals={binding.var: binding.binding.to_literal_model() for binding in spec.outputs}) + literals={binding.var: binding.binding.to_literal_model() for binding in spec.outputs} + ) } else: - generated_files.update({ - _constants.FUTURES_FILE_NAME: spec - }) + generated_files.update({_constants.FUTURES_FILE_NAME: spec}) return generated_files @@ -234,7 +254,7 @@ def _create_hive_job_node(name, hive_job, metadata): upstream_nodes=[], bindings=[], metadata=_workflow_model.NodeMetadata(name, metadata.timeout, _literal_models.RetryStrategy(0)), - sdk_task=SdkHiveJob(hive_job, metadata) + sdk_task=SdkHiveJob(hive_job, metadata), ) @@ -245,9 +265,7 @@ class SdkHiveJob(_base_task.SdkTask): """ def __init__( - self, - hive_job, - metadata, + self, hive_job, metadata, ): """ :param _qubole.QuboleHiveJob hive_job: Hive job spec diff --git a/flytekit/common/tasks/presto_task.py b/flytekit/common/tasks/presto_task.py index dfc4b5b47f..9612f551a9 100644 --- a/flytekit/common/tasks/presto_task.py +++ b/flytekit/common/tasks/presto_task.py @@ -1,24 +1,21 @@ from __future__ import absolute_import -import six as _six - +import datetime as _datetime +import six as _six from google.protobuf.json_format import MessageToDict as _MessageToDict -from flytekit import __version__ +from flytekit import __version__ from flytekit.common import constants as _constants -from flytekit.common.tasks import task as _base_task -from flytekit.models import ( - interface as _interface_model -) -from flytekit.models import literals as _literals, types as _types, \ - task as _task_model - from 
flytekit.common import interface as _interface -import datetime as _datetime -from flytekit.models import presto as _presto_models -from flytekit.common.types import helpers as _type_helpers from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.tasks import task as _base_task +from flytekit.common.types import helpers as _type_helpers +from flytekit.models import interface as _interface_model +from flytekit.models import literals as _literals +from flytekit.models import presto as _presto_models +from flytekit.models import task as _task_model +from flytekit.models import types as _types class SdkPrestoTask(_base_task.SdkTask): @@ -27,19 +24,19 @@ class SdkPrestoTask(_base_task.SdkTask): """ def __init__( - self, - statement, - output_schema, - routing_group=None, - catalog=None, - schema=None, - task_inputs=None, - interruptible=False, - discoverable=False, - discovery_version=None, - retries=1, - timeout=None, - deprecated=None + self, + statement, + output_schema, + routing_group=None, + catalog=None, + schema=None, + task_inputs=None, + interruptible=False, + discoverable=False, + discovery_version=None, + retries=1, + timeout=None, + deprecated=None, ): """ :param Text statement: Presto query specification @@ -65,21 +62,16 @@ def __init__( metadata = _task_model.TaskMetadata( discoverable, # This needs to have the proper version reflected in it - _task_model.RuntimeMetadata( - _task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, - "python"), + _task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python"), timeout or _datetime.timedelta(seconds=0), _literals.RetryStrategy(retries), interruptible, discovery_version, - deprecated + deprecated, ) presto_query = _presto_models.PrestoQuery( - routing_group=routing_group or "", - catalog=catalog or "", - schema=schema or "", - statement=statement + routing_group=routing_group or "", catalog=catalog or "", schema=schema or "", statement=statement, ) # Here we set the routing_group, catalog, and schema as implicit @@ -88,30 +80,28 @@ def __init__( { "__implicit_routing_group": _interface_model.Variable( type=_types.LiteralType(simple=_types.SimpleType.STRING), - description="The routing group set as an implicit input" + description="The routing group set as an implicit input", ), "__implicit_catalog": _interface_model.Variable( type=_types.LiteralType(simple=_types.SimpleType.STRING), - description="The catalog set as an implicit input" + description="The catalog set as an implicit input", ), "__implicit_schema": _interface_model.Variable( type=_types.LiteralType(simple=_types.SimpleType.STRING), - description="The schema set as an implicit input" - ) + description="The schema set as an implicit input", + ), }, { # Set the schema for the Presto query as an output "results": _interface_model.Variable( type=_types.LiteralType(schema=output_schema.schema_type), - description="The schema for the Presto query" + description="The schema for the Presto query", ) - }) + }, + ) super(SdkPrestoTask, self).__init__( - _constants.SdkTaskType.PRESTO_TASK, - metadata, - i, - _MessageToDict(presto_query.to_flyte_idl()), + _constants.SdkTaskType.PRESTO_TASK, metadata, i, _MessageToDict(presto_query.to_flyte_idl()), ) # Set user provided inputs @@ -132,9 +122,7 @@ def _add_implicit_inputs(self, inputs): def __call__(self, *args, **kwargs): kwargs = self._add_implicit_inputs(kwargs) - return super(SdkPrestoTask, self).__call__( - *args, **kwargs - ) + return 
super(SdkPrestoTask, self).__call__(*args, **kwargs) # Override method in order to set the implicit inputs def _python_std_input_map_to_literal_map(self, inputs): @@ -144,10 +132,10 @@ def _python_std_input_map_to_literal_map(self, inputs): :rtype: flytekit.models.literals.LiteralMap """ inputs = self._add_implicit_inputs(inputs) - return _type_helpers.pack_python_std_map_to_literal_map(inputs, { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) - for k, v in _six.iteritems(self.interface.inputs) - }) + return _type_helpers.pack_python_std_map_to_literal_map( + inputs, + {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs)}, + ) @_exception_scopes.system_entry_point def add_inputs(self, inputs): diff --git a/flytekit/common/tasks/pytorch_task.py b/flytekit/common/tasks/pytorch_task.py index eabb88d2cf..d728db2895 100644 --- a/flytekit/common/tasks/pytorch_task.py +++ b/flytekit/common/tasks/pytorch_task.py @@ -1,21 +1,12 @@ from __future__ import absolute_import -try: - from inspect import getfullargspec as _getargspec -except ImportError: - from inspect import getargspec as _getargspec - -import six as _six -from flytekit.common import constants as _constants -from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.tasks import output as _task_output, sdk_runnable as _sdk_runnable -from flytekit.common.types import helpers as _type_helpers -from flytekit.models import literals as _literal_models, task as _task_models from google.protobuf.json_format import MessageToDict as _MessageToDict +from flytekit.common.tasks import sdk_runnable as _sdk_runnable +from flytekit.models import task as _task_models -class SdkRunnablePytorchContainer(_sdk_runnable.SdkRunnableContainer): +class SdkRunnablePytorchContainer(_sdk_runnable.SdkRunnableContainer): @property def args(self): """ @@ -24,31 +15,30 @@ def args(self): """ return self._args + class SdkPyTorchTask(_sdk_runnable.SdkRunnableTask): def __init__( - self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - discoverable, - timeout, - workers_count, - per_replica_storage_request, - per_replica_cpu_request, - per_replica_gpu_request, - per_replica_memory_request, - per_replica_storage_limit, - per_replica_cpu_limit, - per_replica_gpu_limit, - per_replica_memory_limit, - environment + self, + task_function, + task_type, + discovery_version, + retries, + interruptible, + deprecated, + discoverable, + timeout, + workers_count, + per_replica_storage_request, + per_replica_cpu_request, + per_replica_gpu_request, + per_replica_memory_request, + per_replica_storage_limit, + per_replica_cpu_limit, + per_replica_gpu_limit, + per_replica_memory_limit, + environment, ): - pytorch_job = _task_models.PyTorchJob( - workers_count=workers_count - ).to_flyte_idl() + pytorch_job = _task_models.PyTorchJob(workers_count=workers_count).to_flyte_idl() super(SdkPyTorchTask, self).__init__( task_function=task_function, task_type=task_type, @@ -67,13 +57,10 @@ def __init__( discoverable=discoverable, timeout=timeout, environment=environment, - custom=_MessageToDict(pytorch_job) + custom=_MessageToDict(pytorch_job), ) - def _get_container_definition( - self, - **kwargs - ): + def _get_container_definition(self, **kwargs): """ :rtype: SdkRunnablePytorchContainer """ diff --git a/flytekit/common/tasks/raw_container.py b/flytekit/common/tasks/raw_container.py index a352a7f63a..fd185e6094 100644 --- 
a/flytekit/common/tasks/raw_container.py +++ b/flytekit/common/tasks/raw_container.py @@ -10,7 +10,8 @@ from flytekit.common.tasks import task as _base_task from flytekit.common.types.base_sdk_types import FlyteSdkType from flytekit.configuration import resources as _resource_config -from flytekit.models import literals as _literals, task as _task_models +from flytekit.models import literals as _literals +from flytekit.models import task as _task_models from flytekit.models.interface import Variable @@ -23,19 +24,19 @@ def types_to_variable(t: Dict[str, FlyteSdkType]) -> Dict[str, Variable]: def _get_container_definition( - image: str, - command: List[str], - args: List[str], - data_loading_config: _task_models.DataLoadingConfig, - storage_request: str = None, - cpu_request: str = None, - gpu_request: str = None, - memory_request: str = None, - storage_limit: str = None, - cpu_limit: str = None, - gpu_limit: str = None, - memory_limit: str = None, - environment: Dict[str, str] = None, + image: str, + command: List[str], + args: List[str], + data_loading_config: _task_models.DataLoadingConfig, + storage_request: str = None, + cpu_request: str = None, + gpu_request: str = None, + memory_request: str = None, + storage_limit: str = None, + cpu_limit: str = None, + gpu_limit: str = None, + memory_limit: str = None, + environment: Dict[str, str] = None, ) -> _task_models.Container: storage_limit = storage_limit or _resource_config.DEFAULT_STORAGE_LIMIT.get() storage_request = storage_request or _resource_config.DEFAULT_STORAGE_REQUEST.get() @@ -49,62 +50,26 @@ def _get_container_definition( requests = [] if storage_request: requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.STORAGE, - storage_request - ) + _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_request) ) if cpu_request: - requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.CPU, - cpu_request - ) - ) + requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_request)) if gpu_request: - requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.GPU, - gpu_request - ) - ) + requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_request)) if memory_request: requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.MEMORY, - memory_request - ) + _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_request) ) limits = [] if storage_limit: - limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.STORAGE, - storage_limit - ) - ) + limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_limit)) if cpu_limit: - limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.CPU, - cpu_limit - ) - ) + limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_limit)) if gpu_limit: - limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.GPU, - gpu_limit - ) - ) + limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_limit)) if memory_limit: - limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.MEMORY, - memory_limit - ) - ) + 
limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_limit)) if environment is None: environment = {} @@ -125,35 +90,36 @@ class SdkRawContainerTask(_base_task.SdkTask): Use this task when you want to run an arbitrary container as a task (e.g. external tools, binaries compiled separately as a container completely separate from the container where your Flyte workflow is defined. """ + METADATA_FORMAT_JSON = _task_models.DataLoadingConfig.LITERALMAP_FORMAT_JSON METADATA_FORMAT_YAML = _task_models.DataLoadingConfig.LITERALMAP_FORMAT_YAML METADATA_FORMAT_PROTO = _task_models.DataLoadingConfig.LITERALMAP_FORMAT_PROTO def __init__( - self, - inputs: Dict[str, FlyteSdkType], - image: str, - outputs: Dict[str, FlyteSdkType] = None, - input_data_dir: str = None, - output_data_dir: str = None, - metadata_format: int = METADATA_FORMAT_JSON, - io_strategy: _task_models.IOStrategy=None, - command: List[str] = None, - args: List[str] = None, - storage_request: str = None, - cpu_request: str = None, - gpu_request: str = None, - memory_request: str = None, - storage_limit: str = None, - cpu_limit: str = None, - gpu_limit: str = None, - memory_limit: str = None, - environment: Dict[str, str] = None, - interruptible: bool = False, - discoverable: bool = False, - discovery_version: str = None, - retries: int = 1, - timeout: _datetime.timedelta = None, + self, + inputs: Dict[str, FlyteSdkType], + image: str, + outputs: Dict[str, FlyteSdkType] = None, + input_data_dir: str = None, + output_data_dir: str = None, + metadata_format: int = METADATA_FORMAT_JSON, + io_strategy: _task_models.IOStrategy = None, + command: List[str] = None, + args: List[str] = None, + storage_request: str = None, + cpu_request: str = None, + gpu_request: str = None, + memory_request: str = None, + storage_limit: str = None, + cpu_limit: str = None, + gpu_limit: str = None, + memory_limit: str = None, + environment: Dict[str, str] = None, + interruptible: bool = False, + discoverable: bool = False, + discovery_version: str = None, + retries: int = 1, + timeout: _datetime.timedelta = None, ): """ :param inputs: @@ -193,14 +159,12 @@ def __init__( metadata = _task_models.TaskMetadata( discoverable, # This needs to have the proper version reflected in it - _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, - "python"), + _task_models.RuntimeMetadata(_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python",), timeout or _datetime.timedelta(seconds=0), _literals.RetryStrategy(retries), interruptible, discovery_version, - None + None, ) # The interface is defined using the inputs and outputs @@ -226,10 +190,9 @@ def __init__( gpu_limit=gpu_limit, memory_limit=memory_limit, environment=environment, - ) + ), ) - @_exception_scopes.system_entry_point def add_inputs(self, inputs: Dict[str, Variable]): """ diff --git a/flytekit/common/tasks/sagemaker/hpo_job_task.py b/flytekit/common/tasks/sagemaker/hpo_job_task.py index 6427bdaab2..ad5d7e1291 100644 --- a/flytekit/common/tasks/sagemaker/hpo_job_task.py +++ b/flytekit/common/tasks/sagemaker/hpo_job_task.py @@ -20,15 +20,14 @@ class SdkSimpleHyperparameterTuningJobTask(_sdk_task.SdkTask): - def __init__( - self, - max_number_of_training_jobs: int, - max_parallel_training_jobs: int, - training_job: SdkBuiltinAlgorithmTrainingJobTask, - retries: int = 0, - cacheable: bool = False, - cache_version: str = "", + self, + max_number_of_training_jobs: int, + max_parallel_training_jobs: int, + 
training_job: SdkBuiltinAlgorithmTrainingJobTask, + retries: int = 0, + cacheable: bool = False, + cache_version: str = "", ): """ @@ -54,20 +53,17 @@ def __init__( timeout = _datetime.timedelta(seconds=0) inputs = { - "hyperparameter_tuning_job_config": _interface_model.Variable( - _sdk_types.Types.Proto( - _pb2_hpo_job.HyperparameterTuningJobConfig).to_flyte_literal_type(), "" - ), - } + "hyperparameter_tuning_job_config": _interface_model.Variable( + _sdk_types.Types.Proto(_pb2_hpo_job.HyperparameterTuningJobConfig).to_flyte_literal_type(), "", + ), + } inputs.update(training_job.interface.inputs) super(SdkSimpleHyperparameterTuningJobTask, self).__init__( type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK, metadata=_task_models.TaskMetadata( runtime=_task_models.RuntimeMetadata( - type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, - version=__version__, - flavor='sagemaker' + type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, version=__version__, flavor="sagemaker", ), discoverable=cacheable, timeout=timeout, @@ -82,13 +78,12 @@ def __init__( "model": _interface_model.Variable( type=_types_models.LiteralType( blob=_core_types.BlobType( - format="", - dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE + format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) ), - description="" + description="", ) - } + }, ), custom=MessageToDict(hpo_job), ) diff --git a/flytekit/common/tasks/sagemaker/training_job_task.py b/flytekit/common/tasks/sagemaker/training_job_task.py index 0b4cf35234..446ed738fc 100644 --- a/flytekit/common/tasks/sagemaker/training_job_task.py +++ b/flytekit/common/tasks/sagemaker/training_job_task.py @@ -26,12 +26,12 @@ def _content_type_to_blob_format(content_type: _training_job_models) -> str: class SdkBuiltinAlgorithmTrainingJobTask(_sdk_task.SdkTask): def __init__( - self, - training_job_resource_config: _training_job_models.TrainingJobResourceConfig, - algorithm_specification: _training_job_models.AlgorithmSpecification, - retries: int = 0, - cacheable: bool = False, - cache_version: str = "", + self, + training_job_resource_config: _training_job_models.TrainingJobResourceConfig, + algorithm_specification: _training_job_models.AlgorithmSpecification, + retries: int = 0, + cacheable: bool = False, + cache_version: str = "", ): """ @@ -43,8 +43,7 @@ def __init__( """ # Use the training job model as a measure of type checking self._training_job_model = _training_job_models.TrainingJob( - algorithm_specification=algorithm_specification, - training_job_resource_config=training_job_resource_config, + algorithm_specification=algorithm_specification, training_job_resource_config=training_job_resource_config, ) # Setting flyte-level timeout to 0, and let SageMaker takes the StoppingCondition and terminate the training @@ -55,9 +54,7 @@ def __init__( type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK, metadata=_task_models.TaskMetadata( runtime=_task_models.RuntimeMetadata( - type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, - version=__version__, - flavor='sagemaker' + type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, version=__version__, flavor="sagemaker", ), discoverable=cacheable, timeout=timeout, @@ -69,14 +66,13 @@ def __init__( interface=_interface.TypedInterface( inputs={ "static_hyperparameters": _interface_model.Variable( - type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT), - description="", + type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT), description="", ), "train": 
_interface_model.Variable( type=_idl_types.LiteralType( blob=_core_types.BlobType( format=_content_type_to_blob_format(algorithm_specification.input_content_type), - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, ), ), description="", @@ -85,7 +81,7 @@ def __init__( type=_idl_types.LiteralType( blob=_core_types.BlobType( format=_content_type_to_blob_format(algorithm_specification.input_content_type), - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, ), ), description="", @@ -95,13 +91,12 @@ def __init__( "model": _interface_model.Variable( type=_idl_types.LiteralType( blob=_core_types.BlobType( - format="", - dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE + format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) ), - description="" + description="", ) - } + }, ), custom=MessageToDict(self._training_job_model.to_flyte_idl()), ) diff --git a/flytekit/common/tasks/sdk_dynamic.py b/flytekit/common/tasks/sdk_dynamic.py index 445b2dcb45..5d3fe92e57 100644 --- a/flytekit/common/tasks/sdk_dynamic.py +++ b/flytekit/common/tasks/sdk_dynamic.py @@ -1,21 +1,28 @@ from __future__ import absolute_import -import os as _os - import itertools as _itertools import math +import os as _os + import six as _six -from flytekit.common import constants as _constants, interface as _interface, sdk_bases as _sdk_bases, \ - launch_plan as _launch_plan, workflow as _workflow +from flytekit.common import constants as _constants +from flytekit.common import interface as _interface +from flytekit.common import launch_plan as _launch_plan +from flytekit.common import sdk_bases as _sdk_bases +from flytekit.common import workflow as _workflow from flytekit.common.core import identifier as _identifier from flytekit.common.exceptions import scopes as _exception_scopes from flytekit.common.mixins import registerable as _registerable -from flytekit.common.tasks import output as _task_output, sdk_runnable as _sdk_runnable, task as _task +from flytekit.common.tasks import output as _task_output +from flytekit.common.tasks import sdk_runnable as _sdk_runnable +from flytekit.common.tasks import task as _task from flytekit.common.types import helpers as _type_helpers from flytekit.common.utils import _dnsify from flytekit.configuration import internal as _internal_config -from flytekit.models import literals as _literal_models, dynamic_job as _dynamic_job, array_job as _array_job +from flytekit.models import array_job as _array_job +from flytekit.models import dynamic_job as _dynamic_job +from flytekit.models import literals as _literal_models class PromiseOutputReference(_task_output.OutputReference): @@ -47,8 +54,8 @@ def _append_node(generated_files, node, nodes, sub_task_node): # Upload inputs to working directory under /array_job.input_ref/inputs.pb input_path = _os.path.join(node.id, _constants.INPUT_FILE_NAME) generated_files[input_path] = _literal_models.LiteralMap( - literals={binding.var: binding.binding.to_literal_model() for binding in - sub_task_node.inputs}) + literals={binding.var: binding.binding.to_literal_model() for binding in sub_task_node.inputs} + ) class SdkDynamicTaskMixin(object): @@ -75,9 +82,9 @@ def _create_array_job(self, inputs_prefix): :param str inputs_prefix: :rtype: _array_job.ArrayJob """ - return _array_job.ArrayJob(parallelism=self._max_concurrency if self._max_concurrency else 0, - size=1, - 
min_successes=1) + return _array_job.ArrayJob( + parallelism=self._max_concurrency if self._max_concurrency else 0, size=1, min_successes=1, + ) @staticmethod def _can_run_as_array(task_type): @@ -109,9 +116,10 @@ def _produce_dynamic_job_spec(self, context, inputs): :param flytekit.models.literals.LiteralMap literal_map inputs: :rtype: (_dynamic_job.DynamicJobSpec, dict[Text, flytekit.models.common.FlyteIdlEntity]) """ - inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std(inputs, { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs) - }) + inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std( + inputs, + {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs)}, + ) outputs_dict = { name: PromiseOutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type)) for name, variable in _six.iteritems(self.interface.outputs) @@ -120,13 +128,18 @@ def _produce_dynamic_job_spec(self, context, inputs): # Because users declare both inputs and outputs in their functions signatures, merge them together # before calling user code inputs_dict.update(outputs_dict) - yielded_sub_tasks = [sub_task for sub_task in - self._execute_user_code(context, inputs_dict) or []] + yielded_sub_tasks = [sub_task for sub_task in self._execute_user_code(context, inputs_dict) or []] upstream_nodes = list() - output_bindings = [_literal_models.Binding(var=name, binding=_interface.BindingData.from_python_std( - b.sdk_type.to_flyte_literal_type(), b.raw_value, upstream_nodes=upstream_nodes)) - for name, b in _six.iteritems(outputs_dict)] + output_bindings = [ + _literal_models.Binding( + var=name, + binding=_interface.BindingData.from_python_std( + b.sdk_type.to_flyte_literal_type(), b.raw_value, upstream_nodes=upstream_nodes, + ), + ) + for name, b in _six.iteritems(outputs_dict) + ] upstream_nodes = set(upstream_nodes) generated_files = {} @@ -159,7 +172,7 @@ def _produce_dynamic_job_spec(self, context, inputs): _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(), _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(), executable.platform_valid_name, - _internal_config.TASK_VERSION.get() or _internal_config.VERSION.get() + _internal_config.TASK_VERSION.get() or _internal_config.VERSION.get(), ) # Generate an id that's unique in the document (if the same task is used multiple times with @@ -196,7 +209,10 @@ def _produce_dynamic_job_spec(self, context, inputs): else: array_job = self._create_array_job(inputs_prefix=unique_node_id) node = sub_task_node.assign_id_and_return(unique_node_id) - array_job_index[sub_task_node.executable_sdk_object] = (array_job, node) + array_job_index[sub_task_node.executable_sdk_object] = ( + array_job, + node, + ) node_index = _six.text_type(array_job.size - 1) for k, node_output in _six.iteritems(sub_task_node.outputs): @@ -207,8 +223,8 @@ def _produce_dynamic_job_spec(self, context, inputs): # Upload inputs to working directory under /array_job.input_ref//inputs.pb input_path = _os.path.join(node.id, node_index, _constants.INPUT_FILE_NAME) generated_files[input_path] = _literal_models.LiteralMap( - literals={binding.var: binding.binding.to_literal_model() for binding in - sub_task_node.inputs}) + literals={binding.var: binding.binding.to_literal_model() for binding in sub_task_node.inputs} + ) else: node = sub_task_node.assign_id_and_return(unique_node_id) tasks.add(sub_task_node.executable_sdk_object) @@ -217,8 +233,11 @@ 
def _produce_dynamic_job_spec(self, context, inputs): # assign custom field to the ArrayJob properties computed. for task, (array_job, _) in _six.iteritems(array_job_index): # TODO: Reconstruct task template object instead of modifying an existing one? - tasks.add(task.assign_custom_and_return(array_job.to_dict()).assign_type_and_return( - _constants.SdkTaskType.CONTAINER_ARRAY_TASK)) + tasks.add( + task.assign_custom_and_return(array_job.to_dict()).assign_type_and_return( + _constants.SdkTaskType.CONTAINER_ARRAY_TASK + ) + ) # min_successes is absolute; it's computed as the complement of allowed_failure_ratio multiplied by the # total number of tasks to get an absolute count. @@ -228,7 +247,8 @@ def _produce_dynamic_job_spec(self, context, inputs): tasks=list(tasks), nodes=nodes, outputs=output_bindings, - subworkflows=list(sub_workflows)) + subworkflows=list(sub_workflows), + ) return dynamic_job_spec, generated_files @@ -252,17 +272,18 @@ def execute(self, context, inputs): if len(spec.nodes) == 0: return { _constants.OUTPUT_FILE_NAME: _literal_models.LiteralMap( - literals={binding.var: binding.binding.to_literal_model() for binding in spec.outputs}) + literals={binding.var: binding.binding.to_literal_model() for binding in spec.outputs} + ) } else: - generated_files.update({ - _constants.FUTURES_FILE_NAME: spec - }) + generated_files.update({_constants.FUTURES_FILE_NAME: spec}) return generated_files -class SdkDynamicTask(SdkDynamicTaskMixin, _sdk_runnable.SdkRunnableTask, metaclass=_sdk_bases.ExtendedSdkType): +class SdkDynamicTask( + SdkDynamicTaskMixin, _sdk_runnable.SdkRunnableTask, metaclass=_sdk_bases.ExtendedSdkType, +): """ This class includes the additional logic for building a task that executes @@ -271,27 +292,27 @@ class SdkDynamicTask(SdkDynamicTaskMixin, _sdk_runnable.SdkRunnableTask, metacla """ def __init__( - self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - allowed_failure_ratio, - max_concurrency, - environment, - custom + self, + task_function, + task_type, + discovery_version, + retries, + interruptible, + deprecated, + storage_request, + cpu_request, + gpu_request, + memory_request, + storage_limit, + cpu_limit, + gpu_limit, + memory_limit, + discoverable, + timeout, + allowed_failure_ratio, + max_concurrency, + environment, + custom, ): """ :param task_function: Function containing user code. This will be executed via the SDK's engine.
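To make the min_successes arithmetic described in the comment above concrete, here is a minimal standalone sketch; it is plain Python rather than the flytekit API, and the helper name is hypothetical:

```python
import math


def min_successes(num_tasks, allowed_failure_ratio=0.0):
    # The complement of the allowed failure ratio, scaled by the number of
    # sub-tasks, yields the absolute number of successes required.
    return int(math.ceil((1.0 - allowed_failure_ratio) * num_tasks))


assert min_successes(10) == 10       # no failures tolerated
assert min_successes(10, 0.25) == 8  # up to 2 of the 10 sub-tasks may fail
```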
@@ -316,8 +337,25 @@ def __init__( :param dict[Text, T] custom: """ _sdk_runnable.SdkRunnableTask.__init__( - self, task_function, task_type, discovery_version, retries, interruptible, deprecated, - storage_request, cpu_request, gpu_request, memory_request, storage_limit, - cpu_limit, gpu_limit, memory_limit, discoverable, timeout, environment, custom) + self, + task_function, + task_type, + discovery_version, + retries, + interruptible, + deprecated, + storage_request, + cpu_request, + gpu_request, + memory_request, + storage_limit, + cpu_limit, + gpu_limit, + memory_limit, + discoverable, + timeout, + environment, + custom, + ) SdkDynamicTaskMixin.__init__(self, allowed_failure_ratio, max_concurrency) diff --git a/flytekit/common/tasks/sdk_runnable.py b/flytekit/common/tasks/sdk_runnable.py index a9b73c54f0..f868475037 100644 --- a/flytekit/common/tasks/sdk_runnable.py +++ b/flytekit/common/tasks/sdk_runnable.py @@ -8,14 +8,21 @@ import six as _six from flytekit import __version__ -from flytekit.common import interface as _interface, constants as _constants, sdk_bases as _sdk_bases -from flytekit.common.exceptions import user as _user_exceptions, scopes as _exception_scopes -from flytekit.common.tasks import task as _base_task, output as _task_output +from flytekit.common import constants as _constants +from flytekit.common import interface as _interface +from flytekit.common import sdk_bases as _sdk_bases +from flytekit.common.core.identifier import WorkflowExecutionIdentifier +from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.tasks import output as _task_output +from flytekit.common.tasks import task as _base_task from flytekit.common.types import helpers as _type_helpers -from flytekit.configuration import sdk as _sdk_config, internal as _internal_config, resources as _resource_config +from flytekit.configuration import internal as _internal_config +from flytekit.configuration import resources as _resource_config +from flytekit.configuration import sdk as _sdk_config from flytekit.engines import loader as _engine_loader -from flytekit.models import literals as _literal_models, task as _task_models -from flytekit.common.core.identifier import WorkflowExecutionIdentifier +from flytekit.models import literals as _literal_models +from flytekit.models import task as _task_models class ExecutionParameters(object): @@ -97,23 +104,10 @@ def execution_id(self): class SdkRunnableContainer(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _task_models.Container)): - def __init__( - self, - command, - args, - resources, - env, - config, + self, command, args, resources, env, config, ): - super(SdkRunnableContainer, self).__init__( - "", - command, - args, - resources, - env or {}, - config - ) + super(SdkRunnableContainer, self).__init__("", command, args, resources, env or {}, config) @property def args(self): @@ -159,25 +153,25 @@ class SdkRunnableTask(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _base_task """ def __init__( - self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - environment, - custom + self, + task_function, + task_type, + discovery_version, + retries, + interruptible, + deprecated, + storage_request, + cpu_request, + gpu_request, + memory_request, + storage_limit, + cpu_limit, 
+ gpu_limit, + memory_limit, + discoverable, + timeout, + environment, + custom, ): """ :param task_function: Function containing user code. This will be executed via the SDK's engine. @@ -206,15 +200,13 @@ def __init__( _task_models.TaskMetadata( discoverable, _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, - __version__, - 'python' + _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python", ), timeout, _literal_models.RetryStrategy(retries), interruptible, discovery_version, - deprecated + deprecated, ), _interface.TypedInterface({}, {}), custom, @@ -227,8 +219,8 @@ def __init__( cpu_limit=cpu_limit, gpu_limit=gpu_limit, memory_limit=memory_limit, - environment=environment - ) + environment=environment, + ), ) self.id._name = "{}.{}".format(self.task_module, self.task_function_name) @@ -245,7 +237,7 @@ def add_inputs(self, inputs): """ self._validate_inputs(inputs) self.interface.inputs.update(inputs) - + @classmethod def promote_from_model(cls, base_model): # TODO: If the task exists in this container, we should be able to retrieve it. @@ -276,10 +268,7 @@ def validate(self): raise _user_exceptions.FlyteAssertion( "The task {} is invalid because not all inputs and outputs in the " "task function definition were specified in @outputs and @inputs. " - "We are missing definitions for {}.".format( - self, - missing_args - ) + "We are missing definitions for {}.".format(self, missing_args) ) @_exception_scopes.system_entry_point @@ -289,11 +278,18 @@ def unit_test(self, **input_map): literals. :returns: Depends on the behavior of the specific task in the unit engine. """ - return _engine_loader.get_engine('unit').get_task(self).execute( - _type_helpers.pack_python_std_map_to_literal_map(input_map, { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) - for k, v in _six.iteritems(self.interface.inputs) - }) + return ( + _engine_loader.get_engine("unit") + .get_task(self) + .execute( + _type_helpers.pack_python_std_map_to_literal_map( + input_map, + { + k: _type_helpers.get_sdk_type_from_literal_type(v.type) + for k, v in _six.iteritems(self.interface.inputs) + }, + ) + ) ) @_exception_scopes.system_entry_point @@ -304,11 +300,18 @@ def local_execute(self, **input_map): :rtype: dict[Text, T] :returns: The output produced by this task in Python standard format. """ - return _engine_loader.get_engine('local').get_task(self).execute( - _type_helpers.pack_python_std_map_to_literal_map(input_map, { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) - for k, v in _six.iteritems(self.interface.inputs) - }) + return ( + _engine_loader.get_engine("local") + .get_task(self) + .execute( + _type_helpers.pack_python_std_map_to_literal_map( + input_map, + { + k: _type_helpers.get_sdk_type_from_literal_type(v.type) + for k, v in _six.iteritems(self.interface.inputs) + }, + ) + ) ) def _execute_user_code(self, context, inputs): @@ -334,7 +337,7 @@ def _execute_user_code(self, context, inputs): execution_id=_six.text_type(WorkflowExecutionIdentifier.promote_from_model(context.execution_id)), stats=context.stats, logging=context.logging, - tmp_dir=context.working_directory + tmp_dir=context.working_directory, ), **inputs ) @@ -351,9 +354,10 @@ def execute(self, context, inputs): working directory (with the names provided), which will in turn allow Flyte Propeller to push along the workflow. Whereas the local engine will merely feed the outputs directly into the next node.
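As a usage illustration for the unit_test and local_execute paths reformatted above, here is a hedged sketch against the legacy flytekit.sdk decorator API; the task itself is made up:

```python
from flytekit.sdk.tasks import inputs, outputs, python_task
from flytekit.sdk.types import Types


@inputs(x=Types.Integer)
@outputs(y=Types.Integer)
@python_task
def double(wf_params, x, y):
    y.set(x * 2)


# Both helpers pack the keyword arguments into a LiteralMap, run the task
# through the "unit" or "local" engine, and return the unpacked outputs.
print(double.unit_test(x=21))      # expected: {'y': 42}
print(double.local_execute(x=21))  # expected: {'y': 42}
```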
""" - inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std(inputs, { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs) - }) + inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std( + inputs, + {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs)}, + ) outputs_dict = { name: _task_output.OutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type)) for name, variable in _six.iteritems(self.interface.outputs) @@ -369,17 +373,17 @@ def execute(self, context, inputs): } def _get_container_definition( - self, - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - environment=None, - cls=None, + self, + storage_request=None, + cpu_request=None, + gpu_request=None, + memory_request=None, + storage_limit=None, + cpu_limit=None, + gpu_limit=None, + memory_limit=None, + environment=None, + cls=None, ): """ :param Text storage_request: @@ -406,61 +410,29 @@ def _get_container_definition( requests = [] if storage_request: requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.STORAGE, - storage_request - ) + _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_request) ) if cpu_request: - requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.CPU, - cpu_request - ) - ) + requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_request)) if gpu_request: - requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.GPU, - gpu_request - ) - ) + requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_request)) if memory_request: requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.MEMORY, - memory_request - ) + _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_request) ) limits = [] if storage_limit: limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.STORAGE, - storage_limit - ) + _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_limit) ) if cpu_limit: - limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.CPU, - cpu_limit - ) - ) + limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_limit)) if gpu_limit: - limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.GPU, - gpu_limit - ) - ) + limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_limit)) if memory_limit: limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.MEMORY, - memory_limit - ) + _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_limit) ) return (cls or SdkRunnableContainer)( @@ -474,11 +446,11 @@ def _get_container_definition( "--inputs", "{{.input}}", "--output-prefix", - "{{.outputPrefix}}" + "{{.outputPrefix}}", ], resources=_task_models.Resources(limits=limits, requests=requests), env=environment, - config={} + config={}, ) def _validate_inputs(self, inputs): diff --git a/flytekit/common/tasks/sidecar_task.py b/flytekit/common/tasks/sidecar_task.py 
index c31e7d518c..1f3b9effe5 100644 --- a/flytekit/common/tasks/sidecar_task.py +++ b/flytekit/common/tasks/sidecar_task.py @@ -1,17 +1,14 @@ from __future__ import absolute_import import six as _six - from flyteidl.core import tasks_pb2 as _core_task +from google.protobuf.json_format import MessageToDict as _MessageToDict +from flytekit.common import sdk_bases as _sdk_bases from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.tasks import sdk_dynamic as _sdk_dynamic from flytekit.common.tasks import sdk_runnable as _sdk_runnable -from flytekit.common import sdk_bases as _sdk_bases - from flytekit.models import task as _task_models -from google.protobuf.json_format import MessageToDict as _MessageToDict - from flytekit.plugins import k8s as _lazy_k8s @@ -22,26 +19,28 @@ class SdkSidecarTask(_sdk_runnable.SdkRunnableTask, metaclass=_sdk_bases.Extende """ - def __init__(self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - environment, - pod_spec=None, - primary_container_name=None): + def __init__( + self, + task_function, + task_type, + discovery_version, + retries, + interruptible, + deprecated, + storage_request, + cpu_request, + gpu_request, + memory_request, + storage_limit, + cpu_limit, + gpu_limit, + memory_limit, + discoverable, + timeout, + environment, + pod_spec=None, + primary_container_name=None, + ): """ :param _sdk_runnable.SdkRunnableTask sdk_runnable_task: :param generated_pb2.PodSpec pod_spec: @@ -76,9 +75,7 @@ def __init__(self, self.reconcile_partial_pod_spec_and_task(pod_spec, primary_container_name) - def reconcile_partial_pod_spec_and_task(self, - pod_spec, - primary_container_name): + def reconcile_partial_pod_spec_and_task(self, pod_spec, primary_container_name): """ Assigns the custom field as the reconciled primary container and pod spec definition. :param _sdk_runnable.SdkRunnableTask sdk_runnable_task: @@ -113,20 +110,23 @@ def reconcile_partial_pod_spec_and_task(self, resource_requirements = _lazy_k8s.io.api.core.v1.generated_pb2.ResourceRequirements() for resource in self._container.resources.limits: resource_requirements.limits[ - _core_task.Resources.ResourceName.Name(resource.name).lower()].CopyFrom( - _lazy_k8s.io.apimachinery.pkg.api.resource.generated_pb2.Quantity(string=resource.value)) + _core_task.Resources.ResourceName.Name(resource.name).lower() + ].CopyFrom(_lazy_k8s.io.apimachinery.pkg.api.resource.generated_pb2.Quantity(string=resource.value)) for resource in self._container.resources.requests: resource_requirements.requests[ - _core_task.Resources.ResourceName.Name(resource.name).lower()].CopyFrom( - _lazy_k8s.io.apimachinery.pkg.api.resource.generated_pb2.Quantity(string=resource.value)) + _core_task.Resources.ResourceName.Name(resource.name).lower() + ].CopyFrom(_lazy_k8s.io.apimachinery.pkg.api.resource.generated_pb2.Quantity(string=resource.value)) if resource_requirements.ByteSize(): # Important! Only copy over resource requirements if they are non-empty.
container.resources.CopyFrom(resource_requirements) del container.env[:] container.env.extend( - [_lazy_k8s.io.api.core.v1.generated_pb2.EnvVar(name=key, value=val) for key, val in - _six.iteritems(self._container.env)]) + [ + _lazy_k8s.io.api.core.v1.generated_pb2.EnvVar(name=key, value=val) + for key, val in _six.iteritems(self._container.env) + ] + ) final_containers.append(container) @@ -134,14 +134,15 @@ def reconcile_partial_pod_spec_and_task(self, pod_spec.containers.extend(final_containers) sidecar_job_plugin = _task_models.SidecarJob( - pod_spec=pod_spec, - primary_container_name=primary_container_name, + pod_spec=pod_spec, primary_container_name=primary_container_name, ).to_flyte_idl() self.assign_custom_and_return(_MessageToDict(sidecar_job_plugin)) -class SdkDynamicSidecarTask(_sdk_dynamic.SdkDynamicTaskMixin, SdkSidecarTask, metaclass=_sdk_bases.ExtendedSdkType): +class SdkDynamicSidecarTask( + _sdk_dynamic.SdkDynamicTaskMixin, SdkSidecarTask, metaclass=_sdk_bases.ExtendedSdkType, +): """ This class includes the additional logic for building a task that runs as @@ -149,28 +150,30 @@ class SdkDynamicSidecarTask(_sdk_dynamic.SdkDynamicTaskMixin, SdkSidecarTask, me """ - def __init__(self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - allowed_failure_ratio, - max_concurrency, - environment, - pod_spec=None, - primary_container_name=None): + def __init__( + self, + task_function, + task_type, + discovery_version, + retries, + interruptible, + deprecated, + storage_request, + cpu_request, + gpu_request, + memory_request, + storage_limit, + cpu_limit, + gpu_limit, + memory_limit, + discoverable, + timeout, + allowed_failure_ratio, + max_concurrency, + environment, + pod_spec=None, + primary_container_name=None, + ): """ :param task_function: Function containing user code. This will be executed via the SDK's engine.
:param Text task_type: string describing the task type @@ -216,7 +219,7 @@ def __init__(self, timeout, environment, pod_spec=pod_spec, - primary_container_name=primary_container_name + primary_container_name=primary_container_name, ) _sdk_dynamic.SdkDynamicTaskMixin.__init__(self, allowed_failure_ratio, max_concurrency) diff --git a/flytekit/common/tasks/spark_task.py b/flytekit/common/tasks/spark_task.py index fd9ef24ebf..950fe6cc25 100644 --- a/flytekit/common/tasks/spark_task.py +++ b/flytekit/common/tasks/spark_task.py @@ -7,15 +7,19 @@ import os as _os import sys as _sys + import six as _six +from google.protobuf.json_format import MessageToDict as _MessageToDict + from flytekit.bin import entrypoint as _entrypoint from flytekit.common import constants as _constants from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.tasks import output as _task_output, sdk_runnable as _sdk_runnable +from flytekit.common.tasks import output as _task_output +from flytekit.common.tasks import sdk_runnable as _sdk_runnable from flytekit.common.types import helpers as _type_helpers -from flytekit.models import literals as _literal_models, task as _task_models +from flytekit.models import literals as _literal_models +from flytekit.models import task as _task_models from flytekit.plugins import pyspark as _pyspark -from google.protobuf.json_format import MessageToDict as _MessageToDict class GlobalSparkContext(object): @@ -36,7 +40,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): class SdkRunnableSparkContainer(_sdk_runnable.SdkRunnableContainer): - @property def args(self): """ @@ -51,20 +54,21 @@ class SdkSparkTask(_sdk_runnable.SdkRunnableTask): This class includes the additional logic for building a task that executes as a Spark Job. """ + def __init__( - self, - task_function, - task_type, - discovery_version, - retries, - interruptible, - deprecated, - discoverable, - timeout, - spark_type, - spark_conf, - hadoop_conf, - environment, + self, + task_function, + task_type, + discovery_version, + retries, + interruptible, + deprecated, + discoverable, + timeout, + spark_type, + spark_conf, + hadoop_conf, + environment, ): """ :param task_function: Function containing user code. This will be executed via the SDK's engine. @@ -81,7 +85,7 @@ def __init__( """ spark_exec_path = _os.path.abspath(_entrypoint.__file__) - if spark_exec_path.endswith('.pyc'): + if spark_exec_path.endswith(".pyc"): spark_exec_path = spark_exec_path[:-1] spark_job = _task_models.SparkJob( @@ -125,9 +129,10 @@ def execute(self, context, inputs): working directory (with the names provided), which will in turn allow Flyte Propeller to push along the workflow. Whereas the local engine will merely feed the outputs directly into the next node.
""" - inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std(inputs, { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs) - }) + inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std( + inputs, + {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs)}, + ) outputs_dict = { name: _task_output.OutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type)) for name, variable in _six.iteritems(self.interface.outputs) @@ -142,7 +147,7 @@ def execute(self, context, inputs): execution_id=context.execution_id, stats=context.stats, logging=context.logging, - tmp_dir=context.working_directory + tmp_dir=context.working_directory, ), GlobalSparkContext.get_spark_context(), **inputs_dict @@ -153,10 +158,7 @@ def execute(self, context, inputs): ) } - def _get_container_definition( - self, - **kwargs - ): + def _get_container_definition(self, **kwargs): """ :rtype: SdkRunnableSparkContainer """ diff --git a/flytekit/common/tasks/task.py b/flytekit/common/tasks/task.py index 89166460e6..db60803ee3 100644 --- a/flytekit/common/tasks/task.py +++ b/flytekit/common/tasks/task.py @@ -1,26 +1,31 @@ from __future__ import absolute_import +import hashlib as _hashlib +import json as _json import uuid as _uuid import six as _six +from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct -from google.protobuf import json_format as _json_format, struct_pb2 as _struct - -import hashlib as _hashlib -import json as _json - -from flytekit.common import ( - interface as _interfaces, nodes as _nodes, sdk_bases as _sdk_bases, workflow_execution as _workflow_execution -) +from flytekit.common import interface as _interfaces +from flytekit.common import nodes as _nodes +from flytekit.common import sdk_bases as _sdk_bases +from flytekit.common import workflow_execution as _workflow_execution from flytekit.common.core import identifier as _identifier from flytekit.common.exceptions import scopes as _exception_scopes -from flytekit.common.mixins import registerable as _registerable, hash as _hash_mixin, launchable as _launchable_mixin +from flytekit.common.exceptions import system as _system_exceptions +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.mixins import hash as _hash_mixin +from flytekit.common.mixins import launchable as _launchable_mixin +from flytekit.common.mixins import registerable as _registerable +from flytekit.common.types import helpers as _type_helpers from flytekit.configuration import internal as _internal_config from flytekit.engines import loader as _engine_loader -from flytekit.models import common as _common_model, task as _task_model -from flytekit.models.core import workflow as _workflow_model, identifier as _identifier_model -from flytekit.common.exceptions import user as _user_exceptions, system as _system_exceptions -from flytekit.common.types import helpers as _type_helpers +from flytekit.models import common as _common_model +from flytekit.models import task as _task_model +from flytekit.models.core import identifier as _identifier_model +from flytekit.models.core import workflow as _workflow_model class SdkTask( @@ -32,7 +37,6 @@ class SdkTask( _launchable_mixin.LaunchableEntity, ) ): - def __init__(self, type, metadata, interface, custom, container=None): """ :param Text type: This is used to define additional extensions for use by Propeller or SDK. 
@@ -49,13 +53,13 @@ def __init__(self, type, metadata, interface, custom, container=None): _internal_config.PROJECT.get(), _internal_config.DOMAIN.get(), _uuid.uuid4().hex, - _internal_config.VERSION.get() + _internal_config.VERSION.get(), ), type, metadata, interface, custom, - container=container + container=container, ) @property @@ -99,7 +103,7 @@ def promote_from_model(cls, base_model): metadata=base_model.metadata, interface=_interfaces.TypedInterface.promote_from_model(base_model.interface), custom=base_model.custom, - container=base_model.container + container=base_model.container, ) # Override the newly generated name if one exists in the base model if not base_model.id.is_empty: @@ -133,10 +137,12 @@ def __call__(self, *args, **input_map): # TODO: Remove DEADBEEF return _nodes.SdkNode( id=None, - metadata=_workflow_model.NodeMetadata("DEADBEEF", self.metadata.timeout, self.metadata.retries, self.metadata.interruptible), + metadata=_workflow_model.NodeMetadata( + "DEADBEEF", self.metadata.timeout, self.metadata.retries, self.metadata.interruptible, + ), bindings=sorted(bindings, key=lambda b: b.var), upstream_nodes=upstream_nodes, - sdk_task=self + sdk_task=self, ) @_exception_scopes.system_entry_point @@ -154,7 +160,7 @@ def register(self, project, domain, name, version): self._id = id_to_register _engine_loader.get_engine().get_task(self).register(id_to_register) return _six.text_type(self.id) - except: + except Exception: self._id = old_id raise @@ -257,10 +263,7 @@ def _validate_outputs(self, outputs): ) def __repr__(self): - return "Flyte {task_type}: {interface}".format( - task_type=self.type, - interface=self.interface - ) + return "Flyte {task_type}: {interface}".format(task_type=self.type, interface=self.interface) def _python_std_input_map_to_literal_map(self, inputs): """ @@ -268,10 +271,10 @@ def _python_std_input_map_to_literal_map(self, inputs): to a LiteralMap :rtype: flytekit.models.literals.LiteralMap """ - return _type_helpers.pack_python_std_map_to_literal_map(inputs, { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) - for k, v in _six.iteritems(self.interface.inputs) - }) + return _type_helpers.pack_python_std_map_to_literal_map( + inputs, + {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs)}, + ) def _produce_deterministic_version(self, version=None): """ @@ -291,9 +294,13 @@ def _produce_deterministic_version(self, version=None): # 1) this method is used to compute the version portion of the identifier and # 2 ) the SDK will actually generate a unique name on every task instantiation which is not great for # the reproducibility this method attempts. 
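The hashing strategy described in the comments above can be sketched standalone; the helper name and the sample byte strings below are hypothetical stand-ins for the deterministically serialized protos:

```python
import hashlib


def deterministic_version(task_type, metadata_bytes, interface_bytes, custom):
    # Hash only stable, deterministically serialized parts of the task so
    # that an identical definition always yields the identical version.
    task_body = (task_type, metadata_bytes, interface_bytes, custom)
    return hashlib.md5(str(task_body).encode("utf-8")).hexdigest()


v1 = deterministic_version("python-task", b"\x08\x01", b"\x12\x00", "{}")
v2 = deterministic_version("python-task", b"\x08\x01", b"\x12\x00", "{}")
assert v1 == v2
```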
- task_body = (self.type, self.metadata.to_flyte_idl().SerializeToString(deterministic=True), - self.interface.to_flyte_idl().SerializeToString(deterministic=True), custom) - return _hashlib.md5(str(task_body).encode('utf-8')).hexdigest() + task_body = ( + self.type, + self.metadata.to_flyte_idl().SerializeToString(deterministic=True), + self.interface.to_flyte_idl().SerializeToString(deterministic=True), + custom, + ) + return _hashlib.md5(str(task_body).encode("utf-8")).hexdigest() @_exception_scopes.system_entry_point def register_and_launch(self, project, domain, name=None, version=None, inputs=None): @@ -323,14 +330,22 @@ def register_and_launch(self, project, domain, name=None, version=None, inputs=N try: self._id = id_to_register _engine_loader.get_engine().get_task(self).register(id_to_register) - except: + except Exception: self._id = old_id raise return self.launch(project, domain, inputs=inputs) @_exception_scopes.system_entry_point - def launch_with_literals(self, project, domain, literal_inputs, name=None, notification_overrides=None, - label_overrides=None, annotation_overrides=None): + def launch_with_literals( + self, + project, + domain, + literal_inputs, + name=None, + notification_overrides=None, + label_overrides=None, + annotation_overrides=None, + ): """ Launches a single task execution and returns the execution identifier. :param Text project: @@ -345,13 +360,17 @@ def launch_with_literals(self, project, domain, literal_inputs, name=None, notif :param flytekit.models.common.Annotations annotation_overrides: :rtype: flytekit.common.workflow_execution.SdkWorkflowExecution """ - execution = _engine_loader.get_engine().get_task(self).launch( - project, - domain, - name=name, - inputs=literal_inputs, - notification_overrides=notification_overrides, - label_overrides=label_overrides, - annotation_overrides=annotation_overrides, + execution = ( + _engine_loader.get_engine() + .get_task(self) + .launch( + project, + domain, + name=name, + inputs=literal_inputs, + notification_overrides=notification_overrides, + label_overrides=label_overrides, + annotation_overrides=annotation_overrides, + ) ) return _workflow_execution.SdkWorkflowExecution.promote_from_model(execution) diff --git a/flytekit/common/types/base_sdk_types.py b/flytekit/common/types/base_sdk_types.py index 5430985d79..ec77a28424 100644 --- a/flytekit/common/types/base_sdk_types.py +++ b/flytekit/common/types/base_sdk_types.py @@ -1,13 +1,16 @@ from __future__ import absolute_import -from flytekit.models import literals as _literal_models, common as _common_models -from flytekit.common import sdk_bases as _sdk_bases -from flytekit.common.exceptions import user as _user_exceptions + import abc as _abc + import six as _six +from flytekit.common import sdk_bases as _sdk_bases +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.models import common as _common_models +from flytekit.models import literals as _literal_models + class FlyteSdkType(_six.with_metaclass(_common_models.FlyteABCMeta, _sdk_bases.ExtendedSdkType)): - @_abc.abstractmethod def is_castable_from(cls, other): """ @@ -54,7 +57,6 @@ def __hash__(cls): class FlyteSdkValue(_six.with_metaclass(FlyteSdkType, _literal_models.Literal)): - @classmethod def from_flyte_idl(cls, pb2_object): """ @@ -83,7 +85,6 @@ def __call__(cls, *args, **kwargs): class Void(FlyteSdkValue): - @classmethod def is_castable_from(cls, other): """ @@ -106,8 +107,9 @@ def to_flyte_literal_type(cls): """ :rtype: flytekit.models.types.LiteralType """ - raise 
_user_exceptions.FlyteAssertion("A Void type does not have a literal type and cannot be used in this " - "manner.") + raise _user_exceptions.FlyteAssertion( + "A Void type does not have a literal type and cannot be used in this " "manner." + ) @classmethod def promote_from_model(cls, _): @@ -123,7 +125,7 @@ def short_class_string(cls): """ :rtype: Text """ - return 'Void' + return "Void" def __init__(self): super(Void, self).__init__(scalar=_literal_models.Scalar(none_type=_literal_models.Void())) diff --git a/flytekit/common/types/blobs.py b/flytekit/common/types/blobs.py index ec13e2993d..ce7a7498bb 100644 --- a/flytekit/common/types/blobs.py +++ b/flytekit/common/types/blobs.py @@ -1,23 +1,23 @@ from __future__ import absolute_import +import six as _six + from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import base_sdk_types as _base_sdk_types from flytekit.common.types.impl import blobs as _blob_impl -from flytekit.models import types as _idl_types, literals as _literals +from flytekit.models import literals as _literals +from flytekit.models import types as _idl_types from flytekit.models.core import types as _core_types -import six as _six - class BlobInstantiator(_base_sdk_types.InstantiableType): - @staticmethod def create_at_known_location(location): """ :param Text location: :rtype: flytekit.common.types.impl.blobs.Blob """ - return _blob_impl.Blob.create_at_known_location(location, mode='wb') + return _blob_impl.Blob.create_at_known_location(location, mode="wb") @staticmethod def fetch(remote_path, local_path=None): @@ -27,7 +27,7 @@ def fetch(remote_path, local_path=None): this location is NOT managed and the blob will not be cleaned up upon exit. :rtype: flytekit.common.types.impl.blobs.Blob """ - return _blob_impl.Blob.fetch(remote_path, mode='rb', local_path=local_path) + return _blob_impl.Blob.fetch(remote_path, mode="rb", local_path=local_path) def __call__(cls, *args, **kwargs): """ @@ -39,14 +39,13 @@ def __call__(cls, *args, **kwargs): :rtype: flytekit.common.types.impl.blobs.Blob """ if not args and not kwargs: - return _blob_impl.Blob.create_at_any_location(mode='wb') + return _blob_impl.Blob.create_at_any_location(mode="wb") else: return super(BlobInstantiator, cls).__call__(*args, **kwargs) # TODO: Make blobs and schemas pluggable class Blob(_six.with_metaclass(BlobInstantiator, _base_sdk_types.FlyteSdkValue)): - @classmethod def from_string(cls, string_value): """ @@ -55,7 +54,7 @@ def from_string(cls, string_value): """ if not string_value: _user_exceptions.FlyteValueException(string_value, "Cannot create a Blob from the provided path value.") - return cls(_blob_impl.Blob.from_string(string_value, mode='rb')) + return cls(_blob_impl.Blob.from_string(string_value, mode="rb")) @classmethod def is_castable_from(cls, other): @@ -86,10 +85,7 @@ def to_flyte_literal_type(cls): :rtype: flytekit.models.types.LiteralType """ return _idl_types.LiteralType( - blob=_core_types.BlobType( - format="", - dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE - ) + blob=_core_types.BlobType(format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE) ) @classmethod @@ -106,7 +102,7 @@ def short_class_string(cls): """ :rtype: Text """ - return 'Blob' + return "Blob" def __init__(self, value): """ @@ -126,19 +122,20 @@ def short_string(self): """ return "Blob(uri={}{})".format( self.scalar.blob.uri, - ", format={}".format(self.scalar.blob.metadata.type.format) if self.scalar.blob.metadata.type.format else "" + ", 
format={}".format(self.scalar.blob.metadata.type.format) + if self.scalar.blob.metadata.type.format + else "", ) class MultiPartBlobInstantiator(_base_sdk_types.InstantiableType): - @staticmethod def create_at_known_location(location): """ :param Text location: :rtype: flytekit.common.types.impl.blobs.MultiPartBlob """ - return _blob_impl.MultiPartBlob.create_at_known_location(location, mode='wb') + return _blob_impl.MultiPartBlob.create_at_known_location(location, mode="wb") @staticmethod def fetch(remote_path, local_path=None): @@ -148,7 +145,7 @@ def fetch(remote_path, local_path=None): this location is NOT managed and the blob will not be cleaned up upon exit. :rtype: flytekit.common.types.impl.blobs.MultiPartBlob """ - return _blob_impl.MultiPartBlob.fetch(remote_path, mode='rb', local_path=local_path) + return _blob_impl.MultiPartBlob.fetch(remote_path, mode="rb", local_path=local_path) def __call__(cls, *args, **kwargs): """ @@ -161,13 +158,12 @@ def __call__(cls, *args, **kwargs): :rtype: flytekit.common.types.impl.blobs.MultiPartBlob """ if not args and not kwargs: - return _blob_impl.MultiPartBlob.create_at_any_location(mode='wb') + return _blob_impl.MultiPartBlob.create_at_any_location(mode="wb") else: return super(MultiPartBlobInstantiator, cls).__call__(*args, **kwargs) class MultiPartBlob(_six.with_metaclass(MultiPartBlobInstantiator, _base_sdk_types.FlyteSdkValue)): - @classmethod def from_string(cls, string_value): """ @@ -175,9 +171,10 @@ def from_string(cls, string_value): :rtype: MultiPartBlob """ if not string_value: - _user_exceptions.FlyteValueException(string_value, "Cannot create a MultiPartBlob from the provided path " - "value.") - return cls(_blob_impl.MultiPartBlob.from_string(string_value, mode='rb')) + _user_exceptions.FlyteValueException( + string_value, "Cannot create a MultiPartBlob from the provided path " "value.", + ) + return cls(_blob_impl.MultiPartBlob.from_string(string_value, mode="rb")) @classmethod def is_castable_from(cls, other): @@ -208,10 +205,7 @@ def to_flyte_literal_type(cls): :rtype: flytekit.models.types.LiteralType """ return _idl_types.LiteralType( - blob=_core_types.BlobType( - format="", - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART - ) + blob=_core_types.BlobType(format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) ) @classmethod @@ -228,7 +222,7 @@ def short_class_string(cls): """ :rtype: Text """ - return 'MultiPartBlob' + return "MultiPartBlob" def __init__(self, value): """ @@ -248,19 +242,20 @@ def short_string(self): """ return "MultiPartBlob(uri={}{})".format( self.scalar.blob.uri, - ", format={}".format(self.scalar.blob.metadata.type.format) if self.scalar.blob.metadata.type.format else "" + ", format={}".format(self.scalar.blob.metadata.type.format) + if self.scalar.blob.metadata.type.format + else "", ) class CsvInstantiator(BlobInstantiator): - @staticmethod def create_at_known_location(location): """ :param Text location: :rtype: flytekit.common.types.impl.blobs.CSV """ - return _blob_impl.Blob.create_at_known_location(location, mode='w', format='csv') + return _blob_impl.Blob.create_at_known_location(location, mode="w", format="csv") @staticmethod def fetch(remote_path, local_path=None): @@ -270,7 +265,7 @@ def fetch(remote_path, local_path=None): this location is NOT managed and the blob will not be cleaned up upon exit. 
:rtype: flytekit.common.types.impl.blobs.CSV """ - return _blob_impl.Blob.fetch(remote_path, local_path=local_path, mode='r', format='csv') + return _blob_impl.Blob.fetch(remote_path, local_path=local_path, mode="r", format="csv") def __call__(cls, *args, **kwargs): """ @@ -283,13 +278,12 @@ def __call__(cls, *args, **kwargs): :rtype: flytekit.common.types.impl.blobs.CSV """ if not args and not kwargs: - return _blob_impl.Blob.create_at_any_location(mode='w', format='csv') + return _blob_impl.Blob.create_at_any_location(mode="w", format="csv") else: return super(CsvInstantiator, cls).__call__(*args, **kwargs) class CSV(_six.with_metaclass(CsvInstantiator, Blob)): - @classmethod def from_string(cls, string_value): """ @@ -298,7 +292,7 @@ def from_string(cls, string_value): """ if not string_value: _user_exceptions.FlyteValueException(string_value, "Cannot create a CSV from the provided path value.") - return cls(_blob_impl.Blob.from_string(string_value, format='csv', mode='r')) + return cls(_blob_impl.Blob.from_string(string_value, format="csv", mode="r")) @classmethod def is_castable_from(cls, other): @@ -331,10 +325,7 @@ def to_flyte_literal_type(cls): :rtype: flytekit.models.types.LiteralType """ return _idl_types.LiteralType( - blob=_core_types.BlobType( - format="csv", - dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE - ) + blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,) ) @classmethod @@ -344,14 +335,14 @@ def promote_from_model(cls, literal_model): :param flytekit.models.literals.Literal literal_model: :rtype: CSV """ - return cls(_blob_impl.Blob.promote_from_model(literal_model.scalar.blob, mode='r')) + return cls(_blob_impl.Blob.promote_from_model(literal_model.scalar.blob, mode="r")) @classmethod def short_class_string(cls): """ :rtype: Text """ - return 'CSV' + return "CSV" def __init__(self, value): """ @@ -361,14 +352,13 @@ def __init__(self, value): class MultiPartCsvInstantiator(MultiPartBlobInstantiator): - @staticmethod def create_at_known_location(location): """ :param Text location: :rtype: flytekit.common.types.impl.blobs.MultiPartBlob """ - return _blob_impl.MultiPartBlob.create_at_known_location(location, mode='w', format="csv") + return _blob_impl.MultiPartBlob.create_at_known_location(location, mode="w", format="csv") @staticmethod def fetch(remote_path, local_path=None): @@ -378,7 +368,7 @@ def fetch(remote_path, local_path=None): this location is NOT managed and the blob will not be cleaned up upon exit. 
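As a small illustration of the CSV type above, a sketch that exercises only the I/O-free class methods; the path is made up:

```python
from flytekit.common.types import blobs

# The CSV literal type is a single-part blob whose format is "csv".
literal_type = blobs.CSV.to_flyte_literal_type()
assert literal_type.blob.format == "csv"

# from_string wraps an existing path without performing any I/O; the
# instantiator fixes mode="r" and format="csv" as shown above.
csv_value = blobs.CSV.from_string("/tmp/data.csv")
print(csv_value.short_string())  # Blob(uri=/tmp/data.csv, format=csv)
```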
:rtype: flytekit.common.types.impl.blobs.MultiPartCSV """ - return _blob_impl.MultiPartBlob.fetch(remote_path, local_path=local_path, mode='r', format="csv") + return _blob_impl.MultiPartBlob.fetch(remote_path, local_path=local_path, mode="r", format="csv") def __call__(cls, *args, **kwargs): """ @@ -391,13 +381,12 @@ def __call__(cls, *args, **kwargs): :rtype: flytekit.common.types.impl.blobs.MultiPartCSV """ if not args and not kwargs: - return _blob_impl.MultiPartBlob.create_at_any_location(mode='w', format="csv") + return _blob_impl.MultiPartBlob.create_at_any_location(mode="w", format="csv") else: return super(MultiPartCsvInstantiator, cls).__call__(*args, **kwargs) class MultiPartCSV(_six.with_metaclass(MultiPartCsvInstantiator, MultiPartBlob)): - @classmethod def from_string(cls, string_value): """ @@ -406,10 +395,9 @@ def from_string(cls, string_value): """ if not string_value: _user_exceptions.FlyteValueException( - string_value, - "Cannot create a MultiPartCSV from the provided path value." + string_value, "Cannot create a MultiPartCSV from the provided path value.", ) - return cls(_blob_impl.MultiPartBlob.from_string(string_value, format='csv', mode='r')) + return cls(_blob_impl.MultiPartBlob.from_string(string_value, format="csv", mode="r")) @classmethod def is_castable_from(cls, other): @@ -431,8 +419,7 @@ def from_python_std(cls, t_value): elif isinstance(t_value, _blob_impl.MultiPartBlob): if t_value.metadata.type.format != "csv": raise _user_exceptions.FlyteValueException( - t_value, - "Multi Part Blob is in incorrect format. Expected CSV." + t_value, "Multi Part Blob is in incorrect format. Expected CSV." ) blob = t_value else: @@ -445,10 +432,7 @@ def to_flyte_literal_type(cls): :rtype: flytekit.models.types.LiteralType """ return _idl_types.LiteralType( - blob=_core_types.BlobType( - format="csv", - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART - ) + blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) ) @classmethod @@ -458,14 +442,14 @@ def promote_from_model(cls, literal_model): :param flytekit.models.literals.Literal literal_model: :rtype: MultiPartCSV """ - return cls(_blob_impl.MultiPartBlob.promote_from_model(literal_model.scalar.blob, mode='r')) + return cls(_blob_impl.MultiPartBlob.promote_from_model(literal_model.scalar.blob, mode="r")) @classmethod def short_class_string(cls): """ :rtype: Text """ - return 'MultiPartCSV' + return "MultiPartCSV" def __init__(self, value): """ diff --git a/flytekit/common/types/containers.py b/flytekit/common/types/containers.py index 211ac7fa08..46fe545aac 100644 --- a/flytekit/common/types/containers.py +++ b/flytekit/common/types/containers.py @@ -1,11 +1,13 @@ from __future__ import absolute_import -import six as _six import json as _json +import six as _six + from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import base_sdk_types as _base_sdk_types -from flytekit.models import types as _idl_types, literals as _literals +from flytekit.models import literals as _literals +from flytekit.models import types as _idl_types class CollectionType(_base_sdk_types.FlyteSdkType): @@ -21,7 +23,7 @@ def sub_type(cls): return cls._sub_type def __eq__(cls, other): - return hasattr(other, 'sub_type') and cls.sub_type == other.sub_type + return hasattr(other, "sub_type") and cls.sub_type == other.sub_type def __hash__(cls): # Python 3 checks complain if hash isn't implemented at the same time as equals @@ -47,7 +49,6 @@ def 
__len__(self): class TypedListImpl(_six.with_metaclass(TypedCollectionType, ListImpl)): - @classmethod def from_string(cls, string_value): """ @@ -59,11 +60,13 @@ def from_string(cls, string_value): items = _json.loads(string_value) except ValueError: raise _user_exceptions.FlyteTypeException( - _six.text_type, cls, additional_msg='String not parseable to json {}'.format(string_value)) + _six.text_type, cls, additional_msg="String not parseable to json {}".format(string_value), + ) if type(items) != list: raise _user_exceptions.FlyteTypeException( - _six.text_type, cls, additional_msg='String is not a list {}'.format(string_value)) + _six.text_type, cls, additional_msg="String is not a list {}".format(string_value), + ) # Instead of recursively calling from_string(), we're changing to from_python_std() instead because json # loading naturally interprets all layers, not just the outer layer. @@ -113,17 +116,13 @@ def short_class_string(cls): """ :rtype: Text """ - return 'List<{}>'.format(cls.sub_type.short_class_string()) + return "List<{}>".format(cls.sub_type.short_class_string()) def __init__(self, value): """ :param list[flytekit.common.types.base_sdk_types.FlyteSdkValue] value: List value to wrap """ - super(TypedListImpl, self).__init__( - collection=_literals.LiteralCollection( - literals=value - ) - ) + super(TypedListImpl, self).__init__(collection=_literals.LiteralCollection(literals=value)) def to_python_std(self): """ @@ -140,9 +139,7 @@ def short_string(self): if len(self.collection.literals) > num_to_print: to_print.append("...") return "{}(len={}, [{}])".format( - type(self).short_class_string(), - len(self.collection.literals), - ", ".join(to_print) + type(self).short_class_string(), len(self.collection.literals), ", ".join(to_print), ) def verbose_string(self): @@ -152,8 +149,5 @@ def verbose_string(self): return "{}(\n\tlen={},\n\t[\n\t\t{}\n\t]\n)".format( type(self).short_class_string(), len(self.collection.literals), - ",\n\t\t".join( - "\n\t\t".join(v.verbose_string().splitlines()) - for v in self.collection.literals - ) + ",\n\t\t".join("\n\t\t".join(v.verbose_string().splitlines()) for v in self.collection.literals), ) diff --git a/flytekit/common/types/helpers.py b/flytekit/common/types/helpers.py index 2bb19bdf4d..a5a52f56b6 100644 --- a/flytekit/common/types/helpers.py +++ b/flytekit/common/types/helpers.py @@ -1,10 +1,13 @@ from __future__ import absolute_import +import importlib as _importlib + import six as _six -from flytekit.models import literals as _literal_models -from flytekit.common.exceptions import user as _user_exceptions, scopes as _exception_scopes + +from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.exceptions import user as _user_exceptions from flytekit.configuration import sdk as _sdk_config -import importlib as _importlib +from flytekit.models import literals as _literal_models class _TypeEngineLoader(object): @@ -26,14 +29,13 @@ def _load_engines(cls): raise _user_exceptions.FlyteValueException( module, "Failed to load the type engine because the attribute named '{}' could not be found" - "in the module '{}'.".format( - attr, module_path - ) + "in the module '{}'.".format(attr, module_path), ) engine_impl = getattr(module, attr)() cls._LOADED_ENGINES.append(engine_impl) from flytekit.type_engines.default.flyte import FlyteDefaultTypeEngine as _DefaultEngine + cls._LOADED_ENGINES.append(_DefaultEngine()) @classmethod @@ -66,8 +68,9 @@ def get_sdk_type_from_literal_type(literal_type): out = 
e.get_sdk_type_from_literal_type(literal_type) if out is not None: return out - raise _user_exceptions.FlyteValueException(literal_type, "Could not resolve to a type implementation for this " - "value.") + raise _user_exceptions.FlyteValueException( + literal_type, "Could not resolve to a type implementation for this " "value." + ) def infer_sdk_type_from_literal(literal): @@ -126,8 +129,4 @@ def pack_python_std_map_to_literal_map(std_map, type_map): :rtype: flytekit.models.literals.LiteralMap :raises: flytekit.common.exceptions.user.FlyteTypeException """ - return _literal_models.LiteralMap( - literals={ - k: v.from_python_std(std_map[k]) for k, v in _six.iteritems(type_map) - } - ) + return _literal_models.LiteralMap(literals={k: v.from_python_std(std_map[k]) for k, v in _six.iteritems(type_map)}) diff --git a/flytekit/common/types/impl/blobs.py b/flytekit/common/types/impl/blobs.py index dcf2947e6f..709a92e4f7 100644 --- a/flytekit/common/types/impl/blobs.py +++ b/flytekit/common/types/impl/blobs.py @@ -2,42 +2,42 @@ import os as _os import shutil as _shutil -import six as _six import sys as _sys import uuid as _uuid -from flytekit.common import sdk_bases as _sdk_bases, utils as _utils -from flytekit.common.exceptions import user as _user_exceptions, scopes as _exception_scopes + +import six as _six + +from flytekit.common import sdk_bases as _sdk_bases +from flytekit.common import utils as _utils +from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.exceptions import user as _user_exceptions from flytekit.interfaces.data import data_proxy as _data_proxy from flytekit.models import literals as _literal_models from flytekit.models.core import types as _core_types class Blob(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _literal_models.Blob)): - - def __init__(self, remote_path, mode='rb', format=None): + def __init__(self, remote_path, mode="rb", format=None): """ :param Text remote_path: Path to location where the Blob should be synced to. :param Text mode: File access mode. 'a' and '+' are forbidden. A blob can only be written or read at a time. :param Text format: Format """ - if '+' in mode or 'a' in mode or ('w' in mode and 'r' in mode): + if "+" in mode or "a" in mode or ("w" in mode and "r" in mode): raise _user_exceptions.FlyteAssertion("A blob cannot be read and written at the same time") self._mode = mode self._local_path = None self._file = None super(Blob, self).__init__( _literal_models.BlobMetadata( - type=_core_types.BlobType( - format or "", - _core_types.BlobType.BlobDimensionality.SINGLE - ) + type=_core_types.BlobType(format or "", _core_types.BlobType.BlobDimensionality.SINGLE) ), - remote_path + remote_path, ) @classmethod @_exception_scopes.system_entry_point - def from_python_std(cls, t_value, mode='wb', format=None): + def from_python_std(cls, t_value, mode="wb", format=None): """ :param T t_value: :param Text mode: File access mode. 'a' and '+' are forbidden. A blob can only be written or read at a time. @@ -59,12 +59,12 @@ def from_python_std(cls, t_value, mode='wb', format=None): type(t_value), {_six.text_type, str, Blob}, received_value=t_value, - additional_msg="Unable to create Blob from user-provided value." + additional_msg="Unable to create Blob from user-provided value.", ) @classmethod @_exception_scopes.system_entry_point - def from_string(cls, t_value, mode='wb', format=None): + def from_string(cls, t_value, mode="wb", format=None): """ :param T t_value: :param Text mode: Read or write mode of the object. 
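To see the pack and unpack helpers above round-trip a value, a short sketch; the primitives import path reflects the legacy flytekit layout and should be treated as an assumption:

```python
from flytekit.common.types import helpers, primitives

type_map = {"x": primitives.Integer}

# Python -> LiteralMap -> Python should round-trip cleanly.
literal_map = helpers.pack_python_std_map_to_literal_map({"x": 7}, type_map)
unpacked = helpers.unpack_literal_map_to_sdk_python_std(literal_map, type_map)
assert unpacked["x"] == 7
```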
@@ -75,7 +75,7 @@ def from_string(cls, t_value, mode='wb', format=None): @classmethod @_exception_scopes.system_entry_point - def create_at_known_location(cls, known_remote_location, mode='wb', format=None): + def create_at_known_location(cls, known_remote_location, mode="wb", format=None): """ :param Text known_remote_location: The location to which to write the object. Usually an s3 path. :param Text mode: @@ -86,7 +86,7 @@ def create_at_known_location(cls, known_remote_location, mode='wb', format=None) @classmethod @_exception_scopes.system_entry_point - def create_at_any_location(cls, mode='wb', format=None): + def create_at_any_location(cls, mode="wb", format=None): """ :param Text mode: :param Text format: @@ -96,7 +96,7 @@ def create_at_any_location(cls, mode='wb', format=None): @classmethod @_exception_scopes.system_entry_point - def fetch(cls, remote_path, local_path=None, overwrite=False, mode='rb', format=None): + def fetch(cls, remote_path, local_path=None, overwrite=False, mode="rb", format=None): """ :param Text remote_path: The location from which to fetch the object. Usually an s3 path. :param Text local_path: [Optional] A local path to which to download the object. If specified, the object @@ -112,7 +112,7 @@ def fetch(cls, remote_path, local_path=None, overwrite=False, mode='rb', format= return blob @classmethod - def promote_from_model(cls, model, mode='rb'): + def promote_from_model(cls, model, mode="rb"): """ :param flytekit.models.literals.Blob model: :param Text mode: Read or write mode of the object. @@ -153,9 +153,9 @@ def __enter__(self): raise _user_exceptions.FlyteAssertion("Only one reference can be open to a blob at a time.") if self.local_path is None: - if 'r' in self.mode: + if "r" in self.mode: self.download() - elif 'w' in self.mode: + elif "w" in self.mode: self._generate_local_path() self._file = open(self.local_path, self.mode) @@ -166,7 +166,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): if self._file is not None and not self._file.closed: self._file.close() self._file = None - if 'w' in self.mode: + if "w" in self.mode: self.upload() return False @@ -176,7 +176,8 @@ def _generate_local_path(self): "No temporary file system is present. Either call this method from within the " "context of a task or surround with a 'with LocalTestFileSystem():' block. Or " "specify a path when calling this function. Note: Cleanup is not automatic when a " - "path is specified.") + "path is specified." + ) self._local_path = _data_proxy.LocalWorkingDirectoryContext.get().get_named_tempfile(_uuid.uuid4().hex) @_exception_scopes.system_entry_point @@ -188,7 +189,7 @@ def download(self, local_path=None, overwrite=False): :param bool overwrite: If true and local_path is specified, we will download the blob and overwrite an existing file at that location. Default is False. 
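A hedged sketch of the context-manager lifecycle implemented by __enter__ and __exit__ above; the bucket path is made up, and this only works where a data proxy and a temporary working directory are configured (inside a task, or a LocalTestFileSystem block):

```python
from flytekit.common.types.impl.blobs import Blob

# mode defaults to "wb": entering the context opens a local temp file and
# exiting it uploads that file to the remote location via the data proxy.
out = Blob.create_at_known_location("s3://my-bucket/out.bin")
with out as fh:
    fh.write(b"payload")

# fetch() downloads immediately (mode defaults to "rb"); entering the
# context then simply opens the already-downloaded local file.
inp = Blob.fetch("s3://my-bucket/out.bin")
with inp as fh:
    assert fh.read() == b"payload"
```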
""" - if 'r' not in self._mode: + if "r" not in self._mode: raise _user_exceptions.FlyteAssertion("Cannot download a write-only blob!") if local_path: @@ -200,18 +201,11 @@ def download(self, local_path=None, overwrite=False): if overwrite or not _os.path.exists(self.local_path): # TODO: Introduce system logging # logging.info("Getting {} -> {}".format(self.remote_location, self.local_path)) - _data_proxy.Data.get_data( - self.remote_location, - self.local_path, - is_multipart=False - ) + _data_proxy.Data.get_data(self.remote_location, self.local_path, is_multipart=False) else: raise _user_exceptions.FlyteAssertion( "Cannot download blob to a location that already exists when overwrite is not set to True. " - "Attempted download from {} -> {}".format( - self.remote_location, - self.local_path - ) + "Attempted download from {} -> {}".format(self.remote_location, self.local_path) ) @_exception_scopes.system_entry_point @@ -219,40 +213,34 @@ def upload(self): """ Upload the blob to the remote location """ - if 'w' not in self.mode: + if "w" not in self.mode: raise _user_exceptions.FlyteAssertion("Cannot upload a read-only blob!") elif not self.local_path: - raise _user_exceptions.FlyteAssertion("The Blob is not currently backed by a local file and therefore " - "cannot be uploaded. Please write to this Blob before attempting " - "an upload.") + raise _user_exceptions.FlyteAssertion( + "The Blob is not currently backed by a local file and therefore " + "cannot be uploaded. Please write to this Blob before attempting " + "an upload." + ) else: # TODO: Introduce system logging # logging.info("Putting {} -> {}".format(self.local_path, self.remote_location)) - _data_proxy.Data.put_data( - self.local_path, - self.remote_location, - is_multipart=False - ) + _data_proxy.Data.put_data(self.local_path, self.remote_location, is_multipart=False) class MultiPartBlob(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _literal_models.Blob)): - - def __init__(self, remote_path, mode='rb', format=None): + def __init__(self, remote_path, mode="rb", format=None): """ :param Text remote_path: Path to location where the Blob should be synced to. :param Text mode: File access mode. 'a' and '+' are forbidden. A blob can only be written or read at a time. :param Text format: Format of underlying blob pieces. """ - remote_path = remote_path.strip().rstrip('/') + '/' + remote_path = remote_path.strip().rstrip("/") + "/" super(MultiPartBlob, self).__init__( _literal_models.BlobMetadata( - type=_core_types.BlobType( - format or "", - _core_types.BlobType.BlobDimensionality.MULTIPART - ) + type=_core_types.BlobType(format or "", _core_types.BlobType.BlobDimensionality.MULTIPART) ), - remote_path + remote_path, ) self._is_managed = False self._blobs = [] @@ -260,7 +248,7 @@ def __init__(self, remote_path, mode='rb', format=None): self._mode = mode @classmethod - def promote_from_model(cls, model, mode='rb'): + def promote_from_model(cls, model, mode="rb"): """ :param flytekit.models.literals.Blob model: :param Text mode: File access mode. 'a' and '+' are forbidden. A blob can only be written or read at a time. @@ -270,7 +258,7 @@ def promote_from_model(cls, model, mode='rb'): @classmethod @_exception_scopes.system_entry_point - def create_at_known_location(cls, known_remote_location, mode='wb', format=None): + def create_at_known_location(cls, known_remote_location, mode="wb", format=None): """ :param Text known_remote_location: The location to which to write the object. Usually an s3 path. 
:param Text mode: @@ -281,7 +269,7 @@ def create_at_known_location(cls, known_remote_location, mode='wb', format=None) @classmethod @_exception_scopes.system_entry_point - def create_at_any_location(cls, mode='wb', format=None): + def create_at_any_location(cls, mode="wb", format=None): """ :param Text mode: :param Text format: @@ -291,7 +279,7 @@ def create_at_any_location(cls, mode='wb', format=None): @classmethod @_exception_scopes.system_entry_point - def fetch(cls, remote_path, local_path=None, overwrite=False, mode='rb', format=None): + def fetch(cls, remote_path, local_path=None, overwrite=False, mode="rb", format=None): """ :param Text remote_path: The location from which to fetch the object. Usually an s3 path. :param Text local_path: [Optional] A local path to which to download the object. If specified, the object @@ -308,7 +296,7 @@ def fetch(cls, remote_path, local_path=None, overwrite=False, mode='rb', format= @classmethod @_exception_scopes.system_entry_point - def from_python_std(cls, t_value, mode='wb', format=None): + def from_python_std(cls, t_value, mode="wb", format=None): """ :param T t_value: :param Text mode: Read or write mode of the object. @@ -331,12 +319,12 @@ def from_python_std(cls, t_value, mode='wb', format=None): type(t_value), {str, _six.text_type, MultiPartBlob}, received_value=t_value, - additional_msg="Unable to create Blob from user-provided value." + additional_msg="Unable to create Blob from user-provided value.", ) @classmethod @_exception_scopes.system_entry_point - def from_string(cls, t_value, mode='wb', format=None): + def from_string(cls, t_value, mode="wb", format=None): """ :param T t_value: :param Text mode: Read or write mode of the object. @@ -350,7 +338,7 @@ def __enter__(self): """ :rtype: list[typing.BinaryIO] """ - if 'r' not in self.mode: + if "r" not in self.mode: raise _user_exceptions.FlyteAssertion("Do not enter context to write to directory. Call create_piece") try: @@ -363,8 +351,8 @@ def __enter__(self): "path is specified." ) self._directory = _utils.AutoDeletingTempDir( - _uuid.uuid4().hex, - tmp_dir=_data_proxy.LocalWorkingDirectoryContext.get().name) + _uuid.uuid4().hex, tmp_dir=_data_proxy.LocalWorkingDirectoryContext.get().name, + ) self._is_managed = True self._directory.__enter__() # TODO: Introduce system logging @@ -375,16 +363,13 @@ def __enter__(self): self._blobs = [] file_handles = [] for local_path in sorted(self._directory.list_dir(), key=lambda x: x.lower()): - b = Blob( - _os.path.join(self.remote_location, _os.path.basename(local_path)), - mode=self.mode, - ) + b = Blob(_os.path.join(self.remote_location, _os.path.basename(local_path)), mode=self.mode,) b._local_path = local_path file_handles.append(b.__enter__()) self._blobs.append(b) return file_handles - except: + except Exception: # Exit is idempotent so close partially opened context that way exc_type, exc_obj, exc_tb = _sys.exc_info() self.__exit__(exc_type, exc_obj, exc_tb) @@ -437,18 +422,16 @@ def create_part(self, name=None): used to enforce ordering. If not provided, the name is randomly generated. 
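        Example (an illustrative sketch, not from this change; the remote path and part name
        are hypothetical)::

            mpb = MultiPartBlob.create_at_known_location("s3://my-bucket/my-dir", mode="wb")
            with mpb.create_part("000000") as fileobj:  # zero-padded names preserve ordering
                fileobj.write(b"first piece")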
:rtype: Blob """ - if 'w' not in self.mode: + if "w" not in self.mode: raise _user_exceptions.FlyteAssertion("Cannot create a blob in a read-only multipart blob") if name is None: name = _uuid.uuid4().hex - if ':' in name or '/' in name: + if ":" in name or "/" in name: raise _user_exceptions.FlyteAssertion( - name, - "Cannot create a part of a multi-part object with ':' or '/' in the name.") + name, "Cannot create a part of a multi-part object with ':' or '/' in the name.", + ) return Blob.create_at_known_location( - _os.path.join(self.remote_location, name), - mode=self.mode, - format=self.metadata.type.format + _os.path.join(self.remote_location, name), mode=self.mode, format=self.metadata.type.format, ) @_exception_scopes.system_entry_point @@ -461,7 +444,7 @@ def download(self, local_path=None, overwrite=False): :param bool overwrite: If true and local_path is specified, we will download the blob pieces and overwrite any existing files at that location. Default is False. """ - if 'r' not in self.mode: + if "r" not in self.mode: raise _user_exceptions.FlyteAssertion("Cannot download a write-only object!") if local_path: @@ -471,7 +454,8 @@ def download(self, local_path=None, overwrite=False): "No temporary file system is present. Either call this method from within the " "context of a task or surround with a 'with LocalTestFileSystem():' block. Or " "specify a path when calling this function. Note: Cleanup is not automatic when a " - "path is specified.") + "path is specified." + ) else: local_path = _data_proxy.LocalWorkingDirectoryContext.get().get_named_tempfile(_uuid.uuid4().hex) @@ -485,10 +469,7 @@ def download(self, local_path=None, overwrite=False): else: raise _user_exceptions.FlyteAssertion( "Cannot download multi-part blob to a location that already exists when overwrite is not set to True. " - "Attempted download from {} -> {}".format( - self.remote_location, - self.local_path - ) + "Attempted download from {} -> {}".format(self.remote_location, self.local_path) ) @_exception_scopes.system_entry_point @@ -496,18 +477,16 @@ def upload(self): """ Upload the multi-part blob to the remote location """ - if 'w' not in self.mode: + if "w" not in self.mode: raise _user_exceptions.FlyteAssertion("Cannot upload a read-only multi-part blob!") elif not self.local_path: - raise _user_exceptions.FlyteAssertion("The multi-part blob is not currently backed by a local directoru " - "and therefore cannot be uploaded. Please write to this before " - "attempting an upload.") + raise _user_exceptions.FlyteAssertion( + "The multi-part blob is not currently backed by a local directory " + "and therefore cannot be uploaded. Please write to this before " + "attempting an upload."
+ ) else: # TODO: Introduce system logging # logging.info("Putting {} -> {}".format(self.local_path, self.remote_location)) - _data_proxy.Data.put_data( - self.local_path, - self.remote_location, - is_multipart=True - ) + _data_proxy.Data.put_data(self.local_path, self.remote_location, is_multipart=True) diff --git a/flytekit/common/types/impl/schema.py b/flytekit/common/types/impl/schema.py index e5fd344ec6..1e544d3bd6 100644 --- a/flytekit/common/types/impl/schema.py +++ b/flytekit/common/types/impl/schema.py @@ -1,19 +1,25 @@ from __future__ import absolute_import import collections as _collections -from flytekit.plugins import numpy as _np -from flytekit.plugins import pandas as _pd import os as _os -import six as _six import uuid as _uuid -from flytekit.common import utils as _utils, sdk_bases as _sdk_bases -from flytekit.common.types import primitives as _primitives, base_sdk_types as _base_sdk_types, helpers as _helpers +import six as _six + +from flytekit.common import sdk_bases as _sdk_bases +from flytekit.common import utils as _utils +from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.types import base_sdk_types as _base_sdk_types +from flytekit.common.types import helpers as _helpers +from flytekit.common.types import primitives as _primitives from flytekit.common.types.impl import blobs as _blob_impl -from flytekit.common.exceptions import user as _user_exceptions, scopes as _exception_scopes -from flytekit.interfaces.data import data_proxy as _data_proxy -from flytekit.models import types as _type_models, literals as _literal_models from flytekit.configuration import sdk as _sdk_config +from flytekit.interfaces.data import data_proxy as _data_proxy +from flytekit.models import literals as _literal_models +from flytekit.models import types as _type_models +from flytekit.plugins import numpy as _np +from flytekit.plugins import pandas as _pd # Note: For now, this is only for basic type-checking. We need not differentiate between TINYINT, BIGINT, # and INT or DOUBLE and FLOAT, VARCHAR and STRING, etc. as we will unpack into appropriate Python @@ -30,7 +36,7 @@ def get_supported_literal_types_to_pandas_types(): _primitives.Boolean.to_flyte_literal_type(): {_np.bool}, _primitives.Datetime.to_flyte_literal_type(): {_np.datetime64}, _primitives.Timedelta.to_flyte_literal_type(): {_np.timedelta64}, - _primitives.String.to_flyte_literal_type(): {_np.object_, _np.str_, _np.string_} + _primitives.String.to_flyte_literal_type(): {_np.object_, _np.str_, _np.string_}, } return _SUPPORTED_LITERAL_TYPE_TO_PANDAS_TYPES @@ -42,8 +48,7 @@ def get_supported_literal_types_to_pandas_types(): # work-around where we create an external table with the appropriate schema and write the data to our desired # location. The issue here is that the table information in the meta-store might not get cleaned up during a partial # failure. -_HIVE_QUERY_FORMATTER = \ - """ +_HIVE_QUERY_FORMATTER = """ {stage_query_str} CREATE TEMPORARY TABLE {table}_tmp AS {query_str}; @@ -59,8 +64,7 @@ def get_supported_literal_types_to_pandas_types(): # Once https://issues.apache.org/jira/browse/HIVE-12860 is resolved. We will prefer the following syntax because it # guarantees cleanup on partial failures. 
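# Illustration only -- the values below are made up; the keyword arguments mirror the
# call site in Schema.create_from_hive_query() further down in this file:
example_query = _HIVE_QUERY_FORMATTER.format(
    url="s3://my-bucket/my-prefix/",  # where the parquet output will land
    stage_query_str="SET hive.exec.compress.output=true;",
    query_str="SELECT a, b FROM my_db.my_src",
    columnar_query="*",  # or per-column CAST/rename clauses when columns are typed
    table="f00dcafe",  # in practice a random uuid4 hex
)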
-_HIVE_QUERY_FORMATTER_V2 = \ - """ +_HIVE_QUERY_FORMATTER_V2 = """ CREATE TEMPORARY TABLE {table} AS {query_str}; INSERT OVERWRITE DIRECTORY '{url}' STORED AS PARQUET @@ -70,8 +74,7 @@ def get_supported_literal_types_to_pandas_types(): # Set location in both parts of this query so in case of a partial failure, we will always have some data backing a # partition. -_WRITE_HIVE_PARTITION_QUERY_FORMATTER = \ - """ +_WRITE_HIVE_PARTITION_QUERY_FORMATTER = """ ALTER TABLE {write_table} ADD IF NOT EXISTS {partition_string} LOCATION '{url}'; ALTER TABLE {write_table} {partition_string} SET LOCATION '{url}'; @@ -79,24 +82,21 @@ def get_supported_literal_types_to_pandas_types(): def _format_insert_partition_query(table_name, partition_string, remote_location): - table_pieces = table_name.split('.') + table_pieces = table_name.split(".") if len(table_pieces) > 1: # Hive shell commands don't allow us to alter tables and select databases in the table specification. So # we split the table name and use the 'use' command to choose the correct database. prefix = "use {};\n".format(table_pieces[0]) - table_name = '.'.join(table_pieces[1:]) + table_name = ".".join(table_pieces[1:]) else: prefix = "" return prefix + _WRITE_HIVE_PARTITION_QUERY_FORMATTER.format( - write_table=table_name, - partition_string=partition_string, - url=remote_location + write_table=table_name, partition_string=partition_string, url=remote_location ) class _SchemaIO(object): - def __init__(self, schema_instance, local_dir, mode): """ :param Schema schema_instance: @@ -146,7 +146,7 @@ def seek(self, index): if index < 0 or index > self.chunk_count: raise _user_exceptions.FlyteValueException( index, - "Attempting to seek to a chunk that is out of range. Allowed range is [0, {}]".format(self.chunk_count) + "Attempting to seek to a chunk that is out of range. 
Allowed range is [0, {}]".format(self.chunk_count), ) self._index = index @@ -156,19 +156,17 @@ def tell(self): def __repr__(self): return "{mode} IO Object for {type} @ {location}".format( - type=self._schema.type, - location=self._schema.remote_prefix, - mode=self._mode) + type=self._schema.type, location=self._schema.remote_prefix, mode=self._mode + ) class _SchemaReader(_SchemaIO): - def __init__(self, schema_instance, local_dir): """ :param Schema schema_instance: :param flytekit.common.utils.Directory local_dir: """ - super(_SchemaReader, self).__init__(schema_instance, local_dir, 'Read-Only') + super(_SchemaReader, self).__init__(schema_instance, local_dir, "Read-Only") self.reset_chunks() @_exception_scopes.system_entry_point @@ -196,9 +194,9 @@ def _read_parquet_with_type_promotion_override(chunk, columns, parquet_engine): """ df = None - if parquet_engine == 'fastparquet': - from fastparquet import ParquetFile as _ParquetFile + if parquet_engine == "fastparquet": import fastparquet.thrift_structures as _ts + from fastparquet import ParquetFile as _ParquetFile # https://github.com/dask/fastparquet/issues/414#issuecomment-478983811 df = _pd.read_parquet(chunk, columns=columns, engine=parquet_engine, index=False) @@ -206,12 +204,12 @@ def _read_parquet_with_type_promotion_override(chunk, columns, parquet_engine): pf = _ParquetFile(chunk) schema_column_dtypes = {l.name: l.type for l in list(pf.schema.schema_elements)} - for idx in df_column_types[df_column_types == 'float16'].index.tolist(): + for idx in df_column_types[df_column_types == "float16"].index.tolist(): # A hacky way to get the string representations of the column types of a parquet schema # Reference: # https://github.com/dask/fastparquet/blob/f4ecc67f50e7bf98b2d0099c9589c615ea4b06aa/fastparquet/schema.py if _ts.parquet_thrift.Type._VALUES_TO_NAMES[schema_column_dtypes[idx]] == "BOOLEAN": - df[idx] = df[idx].astype('object') + df[idx] = df[idx].astype("object") df[idx].replace({0: False, 1: True, _pd.np.nan: None}, inplace=True) else: @@ -244,9 +242,10 @@ def read(self, columns=None, concat=False, truncate_extra_columns=True, **kwargs self._access_guard() parquet_engine = _sdk_config.PARQUET_ENGINE.get() - if parquet_engine not in {'fastparquet', 'pyarrow'}: + if parquet_engine not in {"fastparquet", "pyarrow"}: raise _user_exceptions.FlyteAssertion( - "environment variable parquet_engine must be one of 'pyarrow', 'fastparquet', or be unset") + "environment variable parquet_engine must be one of 'pyarrow', 'fastparquet', or be unset" + ) df_out = None if not columns: @@ -264,25 +263,21 @@ def read(self, columns=None, concat=False, truncate_extra_columns=True, **kwargs chunk=chunk, columns=columns, parquet_engine=parquet_engine ) # _pd.read_parquet(chunk, columns=columns, engine=parquet_engine) - for chunk in self._chunks[self._index:] + for chunk in self._chunks[self._index :] if _os.path.getsize(chunk) > 0 ] if len(frames) == 1: df_out = frames[0] elif len(frames) > 1: - df_out = _pd.concat( - frames, - copy=True) + df_out = _pd.concat(frames, copy=True) self._index = len(self._chunks) else: while self._index < len(self._chunks) and df_out is None: # Skip empty chunks so the user appears to have a continuous stream of data. 
if _os.path.getsize(self._chunks[self._index]) > 0: df_out = _SchemaReader._read_parquet_with_type_promotion_override( - chunk=self._chunks[self._index], - columns=columns, - parquet_engine=parquet_engine, - **kwargs) + chunk=self._chunks[self._index], columns=columns, parquet_engine=parquet_engine, **kwargs + ) self._index += 1 if df_out is not None: @@ -301,20 +296,19 @@ def read(self, columns=None, concat=False, truncate_extra_columns=True, **kwargs if len(self._schema.type.columns) > 0: # Avoid using pandas.DataFrame.rename() as this function incurs significant memory overhead df_out.columns = [ - user_column_dict[col] if col in user_columns else col - for col in df_out.columns.values] + user_column_dict[col] if col in user_columns else col for col in df_out.columns.values + ] return df_out class _SchemaWriter(_SchemaIO): - def __init__(self, schema_instance, local_dir): """ :param Schema schema_instance: :param flytekit.common.utils.Directory local_dir: :param Text mode: """ - super(_SchemaWriter, self).__init__(schema_instance, local_dir, 'Write-Only') + super(_SchemaWriter, self).__init__(schema_instance, local_dir, "Write-Only") @_exception_scopes.system_entry_point def close(self): @@ -329,7 +323,7 @@ def close(self): super(_SchemaWriter, self).close() @_exception_scopes.system_entry_point - def write(self, data_frame, coerce_timestamps='us', allow_truncated_timestamps=False): + def write(self, data_frame, coerce_timestamps="us", allow_truncated_timestamps=False): """ Writes data frame as a chunk to the local directory owned by the Schema object. Will later be uploaded to s3. @@ -346,7 +340,8 @@ def write(self, data_frame, coerce_timestamps='us', allow_truncated_timestamps=F expected_type=_pd.DataFrame, received_type=type(data_frame), received_value=data_frame, - additional_msg="Only pandas DataFrame objects can be written to a Schema object") + additional_msg="Only pandas DataFrame objects can be written to a Schema object", + ) self._schema.compare_dataframe_to_schema(data_frame) all_columns = list(data_frame.columns.values) @@ -359,9 +354,8 @@ def write(self, data_frame, coerce_timestamps='us', allow_truncated_timestamps=F try: filename = self._local_dir.get_named_tempfile(_os.path.join(str(self._index).zfill(6))) data_frame.to_parquet( - filename, - coerce_timestamps=coerce_timestamps, - allow_truncated_timestamps=allow_truncated_timestamps) + filename, coerce_timestamps=coerce_timestamps, allow_truncated_timestamps=allow_truncated_timestamps, + ) if self._index == len(self._chunks): self._chunks.append(filename) self._index += 1 @@ -371,7 +365,6 @@ def write(self, data_frame, coerce_timestamps='us', allow_truncated_timestamps=F class _SchemaBackingMpBlob(_blob_impl.MultiPartBlob): - @property def directory(self): """ @@ -388,17 +381,16 @@ def __enter__(self): "specify a path when calling this function." 
) self._directory = _utils.AutoDeletingTempDir( - _uuid.uuid4().hex, - tmp_dir=_data_proxy.LocalWorkingDirectoryContext.get().name + _uuid.uuid4().hex, tmp_dir=_data_proxy.LocalWorkingDirectoryContext.get().name, ) self._is_managed = True self._directory.__enter__() - if 'r' in self.mode: + if "r" in self.mode: _data_proxy.Data.get_data(self.remote_location, self.local_path, is_multipart=True) def __exit__(self, exc_type, exc_val, exc_tb): - if 'w' in self.mode: + if "w" in self.mode: _data_proxy.Data.put_data(self.local_path, self.remote_location, is_multipart=True) return super(_SchemaBackingMpBlob, self).__exit__(exc_type, exc_val, exc_tb) @@ -432,10 +424,7 @@ def columns(self): :rtype: list[flytekit.models.types.SchemaType.SchemaColumn] """ return [ - _type_models.SchemaType.SchemaColumn( - n, - type(self)._LITERAL_TYPE_TO_PROTO_ENUM[v.to_flyte_literal_type()] - ) + _type_models.SchemaType.SchemaColumn(n, type(self)._LITERAL_TYPE_TO_PROTO_ENUM[v.to_flyte_literal_type()]) for n, v in _six.iteritems(self.sdk_columns) ] @@ -446,18 +435,24 @@ def promote_from_model(cls, model): :rtype: SchemaType """ _PROTO_ENUM_TO_SDK_TYPE = { - _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER: - _helpers.get_sdk_type_from_literal_type(_primitives.Integer.to_flyte_literal_type()), - _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT: - _helpers.get_sdk_type_from_literal_type(_primitives.Float.to_flyte_literal_type()), - _type_models.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN: - _helpers.get_sdk_type_from_literal_type(_primitives.Boolean.to_flyte_literal_type()), - _type_models.SchemaType.SchemaColumn.SchemaColumnType.DATETIME: - _helpers.get_sdk_type_from_literal_type(_primitives.Datetime.to_flyte_literal_type()), - _type_models.SchemaType.SchemaColumn.SchemaColumnType.DURATION: - _helpers.get_sdk_type_from_literal_type(_primitives.Timedelta.to_flyte_literal_type()), - _type_models.SchemaType.SchemaColumn.SchemaColumnType.STRING: - _helpers.get_sdk_type_from_literal_type(_primitives.String.to_flyte_literal_type()), + _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER: _helpers.get_sdk_type_from_literal_type( + _primitives.Integer.to_flyte_literal_type() + ), + _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT: _helpers.get_sdk_type_from_literal_type( + _primitives.Float.to_flyte_literal_type() + ), + _type_models.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN: _helpers.get_sdk_type_from_literal_type( + _primitives.Boolean.to_flyte_literal_type() + ), + _type_models.SchemaType.SchemaColumn.SchemaColumnType.DATETIME: _helpers.get_sdk_type_from_literal_type( + _primitives.Datetime.to_flyte_literal_type() + ), + _type_models.SchemaType.SchemaColumn.SchemaColumnType.DURATION: _helpers.get_sdk_type_from_literal_type( + _primitives.Timedelta.to_flyte_literal_type() + ), + _type_models.SchemaType.SchemaColumn.SchemaColumnType.STRING: _helpers.get_sdk_type_from_literal_type( + _primitives.String.to_flyte_literal_type() + ), } return cls([(c.name, _PROTO_ENUM_TO_SDK_TYPE[c.type]) for c in model.columns]) @@ -468,46 +463,55 @@ def _set_columns(self, columns): raise _user_exceptions.FlyteValueException( column, "When specifying a Schema type with a known set of columns. Each column must be " - "specified as a tuple in the form ('name', type).") + "specified as a tuple in the form ('name', type).", + ) if len(column) != 2: raise _user_exceptions.FlyteValueException( column, "When specifying a Schema type with a known set of columns. 
Each column must be " - "specified as a tuple in the form ('name', type).") + "specified as a tuple in the form ('name', type).", + ) name, sdk_type = column sdk_type = _helpers.python_std_to_sdk_type(sdk_type) if not isinstance(name, (str, _six.text_type)): - additional_msg = "When specifying a Schema type with a known set of columns, the first element in" \ - " each tuple must be text." + additional_msg = ( + "When specifying a Schema type with a known set of columns, the first element in" + " each tuple must be text." + ) raise _user_exceptions.FlyteTypeException( received_type=type(name), received_value=name, expected_type={str, _six.text_type}, - additional_msg=additional_msg) - - if not isinstance(sdk_type, _base_sdk_types.FlyteSdkType) or sdk_type.to_flyte_literal_type() not in \ - get_supported_literal_types_to_pandas_types(): - additional_msg = \ - "When specifying a Schema type with a known set of columns, the second element of " \ - "each tuple must be a supported type. Failed for column: {name}".format( - name=name) + additional_msg=additional_msg, + ) + + if ( + not isinstance(sdk_type, _base_sdk_types.FlyteSdkType) + or sdk_type.to_flyte_literal_type() not in get_supported_literal_types_to_pandas_types() + ): + additional_msg = ( + "When specifying a Schema type with a known set of columns, the second element of " + "each tuple must be a supported type. Failed for column: {name}".format(name=name) + ) raise _user_exceptions.FlyteTypeException( expected_type=list(get_supported_literal_types_to_pandas_types().keys()), received_type=sdk_type, - additional_msg=additional_msg) + additional_msg=additional_msg, + ) if name in names_seen: - raise ValueError("The column name {name} was specified multiple times when instantiating the " - "Schema.".format(name=name)) + raise ValueError( + "The column name {name} was specified multiple times when instantiating the " + "Schema.".format(name=name) + ) names_seen.add(name) self._sdk_columns = _collections.OrderedDict(columns) class Schema(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _literal_models.Schema)): - - def __init__(self, remote_path, mode='rb', schema_type=None): + def __init__(self, remote_path, mode="rb", schema_type=None): """ :param Text remote_path: :param Text mode: @@ -515,10 +519,7 @@ def __init__(self, remote_path, mode='rb', schema_type=None): not specified, the schema will be considered generic. """ self._mp_blob = _SchemaBackingMpBlob(remote_path, mode=mode) - super(Schema, self).__init__( - self._mp_blob.uri, - schema_type or SchemaType() - ) + super(Schema, self).__init__(self._mp_blob.uri, schema_type or SchemaType()) self._io_object = None @classmethod @@ -531,7 +532,7 @@ def promote_from_model(cls, model): @classmethod @_exception_scopes.system_entry_point - def create_at_known_location(cls, known_remote_location, mode='wb', schema_type=None): + def create_at_known_location(cls, known_remote_location, mode="wb", schema_type=None): """ :param Text known_remote_location: The location to which to write the object. Usually an s3 path. :param Text mode: @@ -543,7 +544,7 @@ def create_at_known_location(cls, known_remote_location, mode='wb', schema_type= @classmethod @_exception_scopes.system_entry_point - def create_at_any_location(cls, mode='wb', schema_type=None): + def create_at_any_location(cls, mode="wb", schema_type=None): """ :param Text mode: :param SchemaType schema_type: [Optional] If specified, the schema will be forced to conform to this type. 
If @@ -554,7 +555,7 @@ def create_at_any_location(cls, mode='wb', schema_type=None): @classmethod @_exception_scopes.system_entry_point - def fetch(cls, remote_path, local_path=None, overwrite=False, mode='rb', schema_type=None): + def fetch(cls, remote_path, local_path=None, overwrite=False, mode="rb", schema_type=None): """ :param Text remote_path: The location from which to fetch the object. Usually an s3 path. :param Text local_path: [Optional] A local path to which to download the object. If specified, the object @@ -606,7 +607,7 @@ def from_python_std(cls, t_value, schema_type=None): type(t_value), {str, _six.text_type, Schema}, received_value=x, - additional_msg="A Schema object can only be create from a pandas DataFrame or a list of pandas DataFrame." + additional_msg="A Schema object can only be created from a pandas DataFrame or a list of pandas DataFrame.", ) return o else: @@ -614,7 +615,7 @@ def from_python_std(cls, t_value, schema_type=None): type(t_value), {str, _six.text_type, Schema}, received_value=t_value, - additional_msg="Unable to create Schema from user-provided value." + additional_msg="Unable to create Schema from user-provided value.", ) @classmethod @@ -631,12 +632,8 @@ def from_string(cls, string_value, schema_type=None): @classmethod @_exception_scopes.system_entry_point def create_from_hive_query( - cls, - select_query, - stage_query=None, - schema_to_table_name_map=None, - schema_type=None, - known_location=None): + cls, select_query, stage_query=None, schema_to_table_name_map=None, schema_type=None, known_location=None, + ): """ Returns a query that can be submitted to Hive and produce the desired output. It also returns a properly-typed schema object. @@ -652,9 +649,7 @@ def create_from_hive_query( :return: Schema, Text """ schema_object = cls( - known_location or _data_proxy.Data.get_remote_directory(), - mode='wb', - schema_type=schema_type + known_location or _data_proxy.Data.get_remote_directory(), mode="wb", schema_type=schema_type, ) if len(schema_object.type.sdk_columns) > 0: @@ -665,26 +660,31 @@ def create_from_hive_query( columnar_clauses = [] for name, sdk_type in _six.iteritems(schema_object.type.sdk_columns): if sdk_type == _primitives.Float: - columnar_clauses.append("CAST({table_column_name} as double) {schema_name}".format( - table_column_name=schema_to_table_name_map[name], - schema_name=name)) + columnar_clauses.append( + "CAST({table_column_name} as double) {schema_name}".format( + table_column_name=schema_to_table_name_map[name], schema_name=name, + ) + ) else: - columnar_clauses.append("{table_column_name} as {schema_name}".format( - table_column_name=schema_to_table_name_map[name], - schema_name=name)) + columnar_clauses.append( + "{table_column_name} as {schema_name}".format( + table_column_name=schema_to_table_name_map[name], schema_name=name, + ) + ) columnar_query = ",\n\t\t".join(columnar_clauses) else: columnar_query = "*" - stage_query_str = _six.text_type(stage_query or '') + stage_query_str = _six.text_type(stage_query or "") # the stage query should always end with a semicolon - stage_query_str = stage_query_str if stage_query_str.endswith(';') else (stage_query_str + ';') + stage_query_str = stage_query_str if stage_query_str.endswith(";") else (stage_query_str + ";") query = _HIVE_QUERY_FORMATTER.format( url=schema_object.remote_location, stage_query_str=stage_query_str, - query_str=select_query.strip().strip(';'), + query_str=select_query.strip().strip(";"), columnar_query=columnar_query, - table=_uuid.uuid4().hex) +
table=_uuid.uuid4().hex, + ) return schema_object, query @property @@ -754,7 +754,7 @@ def __enter__(self): ) self._mp_blob.__enter__() - if 'r' in self.mode: + if "r" in self.mode: self._io_object = _SchemaReader(self, self.multipart_blob.directory) else: self._io_object = _SchemaWriter(self, self.multipart_blob.directory) @@ -769,7 +769,7 @@ def __repr__(self): return "Schema({columns}) @ {location} ({mode})".format( columns=self.type.columns, location=self.remote_prefix, - mode='read-only' if 'r' in self.mode else 'write-only' + mode="read-only" if "r" in self.mode else "write-only", ) @_exception_scopes.system_entry_point @@ -785,12 +785,13 @@ def download(self, local_path=None, overwrite=False): @_exception_scopes.system_entry_point def get_write_partition_to_hive_table_query( - self, - table_name, - partitions=None, - schema_to_table_name_map=None, - partitions_in_table=False, - append_to_partition=False): + self, + table_name, + partitions=None, + schema_to_table_name_map=None, + partitions_in_table=False, + append_to_partition=False, + ): """ Returns a Hive query string that will update the metatable to point to the data as the new partition. @@ -819,36 +820,41 @@ def get_write_partition_to_hive_table_query( expected_type={str, _six.text_type}, received_type=type(partition_name), received_value=partition_name, - additional_msg="All partition names must be type str.") + additional_msg="All partition names must be type str.", + ) if type(partition_value) not in _ALLOWED_PARTITION_TYPES: raise _user_exceptions.FlyteTypeException( expected_type=_ALLOWED_PARTITION_TYPES, received_type=type(partition_value), received_value=partition_value, - additional_msg="Partition {name} has an unsupported type.".format(name=partition_name) + additional_msg="Partition {name} has an unsupported type.".format(name=partition_name), ) # We need the string to be quoted in the query, so let's take repr of it. if isinstance(partition_value, (str, _six.text_type)): partition_value = repr(partition_value) - partition_conditions.append("{partition_name} = {partition_value}".format( - partition_name=partition_name, - partition_value=partition_value)) + partition_conditions.append( + "{partition_name} = {partition_value}".format( + partition_name=partition_name, partition_value=partition_value + ) + ) partition_formatter = "PARTITION (\n\t{conditions}\n)" partition_string = partition_formatter.format(conditions=",\n\t".join(partition_conditions)) if partitions_in_table and partitions: where_clauses = [] for partition_name, partition_value in partitions: - where_clauses.append("\n\t\t{schema_name} = {value_str} AND ".format( - schema_name=table_to_schema_name_map[partition_name], - value_str=partition_value - )) + where_clauses.append( + "\n\t\t{schema_name} = {value_str} AND ".format( + schema_name=table_to_schema_name_map[partition_name], value_str=partition_value, + ) + ) where_string = "WHERE\n\t\t{where_clauses}".format(where_clauses=" AND\n\t\t".join(where_clauses)) if where_string or partitions_in_table: raise _user_exceptions.FlyteAssertion( - "Currently, the partition values should not be present in the schema pushed to Hive.") + "Currently, the partition values should not be present in the schema pushed to Hive." + ) if append_to_partition: raise _user_exceptions.FlyteAssertion( "Currently, partitions can only be overwritten, they cannot be appended." 
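# Illustration only -- hypothetical table, partition, and location. The helper defined
# near the top of this file prepends "use <db>;" for database-qualified names and then
# emits the two ALTER TABLE statements from _WRITE_HIVE_PARTITION_QUERY_FORMATTER:
example_partition_query = _format_insert_partition_query(
    table_name="my_db.my_table",
    partition_string="PARTITION (\n\tds = '2020-01-01'\n)",
    remote_location="s3://my-bucket/my-prefix",
)
# -> use my_db;
#    ALTER TABLE my_table ADD IF NOT EXISTS PARTITION (...) LOCATION 's3://my-bucket/my-prefix';
#    ALTER TABLE my_table PARTITION (...) SET LOCATION 's3://my-bucket/my-prefix';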
@@ -859,9 +865,8 @@ def get_write_partition_to_hive_table_query( ) return _format_insert_partition_query( - remote_location=self.remote_location, - table_name=table_name, - partition_string=partition_string) + remote_location=self.remote_location, table_name=table_name, partition_string=partition_string, + ) def compare_dataframe_to_schema(self, data_frame, column_subset=None, read=False): """ @@ -891,9 +896,7 @@ def compare_dataframe_to_schema(self, data_frame, column_subset=None, read=False additional_msg = "" raise _user_exceptions.FlyteAssertion( "{} was/where requested but could not be found in the schema: {}.{}".format( - failed_columns, - self.type.sdk_columns, - additional_msg + failed_columns, self.type.sdk_columns, additional_msg ) ) @@ -902,7 +905,7 @@ def compare_dataframe_to_schema(self, data_frame, column_subset=None, read=False expected_type=self.type.sdk_columns, received_type=data_frame.columns, additional_msg="Mismatch between the data frame's column names {} and schema's column names {} " - "with strict_names=True.".format(all_columns, schema_column_names) + "with strict_names=True.".format(all_columns, schema_column_names), ) # This only iterates if the Schema has specified columns. @@ -912,25 +915,27 @@ def compare_dataframe_to_schema(self, data_frame, column_subset=None, read=False # TODO np.issubdtype is deprecated. Replace it if all( - not _np.issubdtype(dtype, allowed_type) - for allowed_type in get_supported_literal_types_to_pandas_types()[literal_type] + not _np.issubdtype(dtype, allowed_type) + for allowed_type in get_supported_literal_types_to_pandas_types()[literal_type] ): if read: read_or_write_msg = "read data frame object from schema" else: read_or_write_msg = "write data frame object to schema" - additional_msg = \ - "Cannot {read_write} because the types do not match. Column " \ - "'{name}' did not pass type checking. Note: If your " \ - "column contains null values, the types might not transition as expected between parquet and " \ - "pandas. For more information, see: " \ + additional_msg = ( + "Cannot {read_write} because the types do not match. Column " + "'{name}' did not pass type checking. Note: If your " + "column contains null values, the types might not transition as expected between parquet and " + "pandas. 
For more information, see: " "http://arrow.apache.org/docs/python/pandas.html#arrow-pandas-conversion".format( - read_write=read_or_write_msg, - name=name) + read_write=read_or_write_msg, name=name + ) + ) raise _user_exceptions.FlyteTypeException( expected_type=get_supported_literal_types_to_pandas_types()[literal_type], received_type=dtype, - additional_msg=additional_msg) + additional_msg=additional_msg, + ) def cast_to(self, other_type): """ @@ -944,16 +949,16 @@ def cast_to(self, other_type): self.type, other_type, additional_msg="Cannot cast because a required column '{}' was not found.".format(k), - received_value=self + received_value=self, ) - if not isinstance(v, _base_sdk_types.FlyteSdkType) or \ - v.to_flyte_literal_type() != self.type.sdk_columns[k].to_flyte_literal_type(): + if ( + not isinstance(v, _base_sdk_types.FlyteSdkType) + or v.to_flyte_literal_type() != self.type.sdk_columns[k].to_flyte_literal_type() + ): raise _user_exceptions.FlyteTypeException( self.type.sdk_columns[k], v, - additional_msg="Cannot cast because the column type for column '{}' does not match.".format( - k - ) + additional_msg="Cannot cast because the column type for column '{}' does not match.".format(k), ) return Schema(self.remote_location, mode=self.mode, schema_type=other_type) @@ -962,18 +967,16 @@ def upload(self): """ Upload the schema to the remote location """ - if 'w' not in self.mode: + if "w" not in self.mode: raise _user_exceptions.FlyteAssertion("Cannot upload a read-only schema!") elif not self.local_path: - raise _user_exceptions.FlyteAssertion("The schema is not currently backed by a local directory " - "and therefore cannot be uploaded. Please write to this before " - "attempting an upload.") + raise _user_exceptions.FlyteAssertion( + "The schema is not currently backed by a local directory " + "and therefore cannot be uploaded. Please write to this before " + "attempting an upload." 
+ ) else: # TODO: Introduce system logging # logging.info("Putting {} -> {}".format(self.local_path, self.remote_location)) - _data_proxy.Data.put_data( - self.local_path, - self.remote_location, - is_multipart=True - ) + _data_proxy.Data.put_data(self.local_path, self.remote_location, is_multipart=True) diff --git a/flytekit/common/types/primitives.py b/flytekit/common/types/primitives.py index fe393515bb..d94350aff5 100644 --- a/flytekit/common/types/primitives.py +++ b/flytekit/common/types/primitives.py @@ -1,20 +1,21 @@ from __future__ import absolute_import +import datetime as _datetime +import json as _json + +import six as _six +from dateutil import parser as _parser +from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct from pytimeparse import parse as _parse_duration_string from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import base_sdk_types as _base_sdk_types -from flytekit.models import types as _idl_types, literals as _literals -from dateutil import parser as _parser -from google.protobuf import json_format as _json_format, struct_pb2 as _struct - -import json as _json -import six as _six -import datetime as _datetime +from flytekit.models import literals as _literals +from flytekit.models import types as _idl_types class Integer(_base_sdk_types.FlyteSdkValue): - @classmethod def from_string(cls, string_value): """ @@ -24,9 +25,11 @@ def from_string(cls, string_value): try: return cls(int(string_value)) except (ValueError, TypeError): - raise _user_exceptions.FlyteTypeException(_six.text_type, int, - additional_msg='String not castable to Integer SDK type:' - ' {}'.format(string_value)) + raise _user_exceptions.FlyteTypeException( + _six.text_type, + int, + additional_msg="String not castable to Integer SDK type:" " {}".format(string_value), + ) @classmethod def is_castable_from(cls, other): @@ -70,7 +73,7 @@ def short_class_string(cls): """ :rtype: Text """ - return 'Integer' + return "Integer" def __init__(self, value): """ @@ -92,7 +95,6 @@ def short_string(self): class Float(_base_sdk_types.FlyteSdkValue): - @classmethod def from_string(cls, string_value): """ @@ -102,9 +104,11 @@ def from_string(cls, string_value): try: return cls(float(string_value)) except ValueError: - raise _user_exceptions.FlyteTypeException(_six.text_type, float, - additional_msg='String not castable to Float SDK type:' - ' {}'.format(string_value)) + raise _user_exceptions.FlyteTypeException( + _six.text_type, + float, + additional_msg="String not castable to Float SDK type:" " {}".format(string_value), + ) @classmethod def is_castable_from(cls, other): @@ -148,7 +152,7 @@ def short_class_string(cls): """ :rtype: Text """ - return 'Float' + return "Float" def __init__(self, value): """ @@ -170,20 +174,19 @@ def short_string(self): class Boolean(_base_sdk_types.FlyteSdkValue): - @classmethod def from_string(cls, string_value): """ :param Text string_value: :rtype: Boolean """ - if string_value == '1' or string_value.lower() == 'true': + if string_value == "1" or string_value.lower() == "true": return cls(True) - elif string_value == '0' or string_value.lower() == 'false': + elif string_value == "0" or string_value.lower() == "false": return cls(False) - raise _user_exceptions.FlyteTypeException(_six.text_type, bool, - additional_msg='String not castable to Boolean SDK ' - 'type: {}'.format(string_value)) + raise _user_exceptions.FlyteTypeException( + _six.text_type, bool, additional_msg="String not 
castable to Boolean SDK " "type: {}".format(string_value), + ) @classmethod def is_castable_from(cls, other): @@ -227,7 +230,7 @@ def short_class_string(cls): """ :rtype: Text """ - return 'Boolean' + return "Boolean" def __init__(self, value): """ @@ -249,7 +252,6 @@ def short_string(self): class String(_base_sdk_types.FlyteSdkValue): - @classmethod def from_string(cls, string_value): """ @@ -258,8 +260,10 @@ def from_string(cls, string_value): """ if type(string_value) == dict or type(string_value) == list: raise _user_exceptions.FlyteTypeException( - type(string_value), _six.text_type, - additional_msg='Should not cast native Python type to string {}'.format(string_value)) + type(string_value), + _six.text_type, + additional_msg="Should not cast native Python type to string {}".format(string_value), + ) return cls(string_value) @classmethod @@ -305,7 +309,7 @@ def short_class_string(cls): """ :rtype: Text """ - return 'String' + return "String" def __init__(self, value): """ @@ -326,7 +330,7 @@ def short_string(self): _TRUNCATE_LENGTH = 100 return "String('{}'{})".format( self.scalar.primitive.string_value[:_TRUNCATE_LENGTH], - " ..." if len(self.scalar.primitive.string_value) > _TRUNCATE_LENGTH else "" + " ..." if len(self.scalar.primitive.string_value) > _TRUNCATE_LENGTH else "", ) def verbose_string(self): @@ -337,7 +341,6 @@ def verbose_string(self): class Datetime(_base_sdk_types.FlyteSdkValue): - @classmethod def from_string(cls, string_value): """ @@ -347,9 +350,11 @@ def from_string(cls, string_value): try: python_std_datetime = _parser.parse(string_value) except ValueError: - raise _user_exceptions.FlyteTypeException(_six.text_type, _datetime.datetime, - additional_msg='String not castable to Datetime ' - 'SDK type: {}'.format(string_value)) + raise _user_exceptions.FlyteTypeException( + _six.text_type, + _datetime.datetime, + additional_msg="String not castable to Datetime " "SDK type: {}".format(string_value), + ) return cls.from_python_std(python_std_datetime) @@ -373,8 +378,9 @@ def from_python_std(cls, t_value): elif type(t_value) != _datetime.datetime: raise _user_exceptions.FlyteTypeException(type(t_value), _datetime.datetime, t_value) elif t_value.tzinfo is None: - raise _user_exceptions.FlyteValueException(t_value, "Datetime objects in Flyte must be timezone aware. " - "tzinfo was found to be None.") + raise _user_exceptions.FlyteValueException( + t_value, "Datetime objects in Flyte must be timezone aware. 
" "tzinfo was found to be None.", + ) return cls(t_value) @classmethod @@ -398,7 +404,7 @@ def short_class_string(cls): """ :rtype: Text """ - return 'Datetime' + return "Datetime" def __init__(self, value): """ @@ -420,7 +426,6 @@ def short_string(self): class Timedelta(_base_sdk_types.FlyteSdkValue): - @classmethod def from_string(cls, string_value): """ @@ -429,9 +434,11 @@ def from_string(cls, string_value): """ td = _parse_duration_string(string_value) if td is None: - raise _user_exceptions.FlyteTypeException(_six.text_type, _datetime.timedelta, - additional_msg='Could not convert string to' - ' time delta: {}'.format(string_value)) + raise _user_exceptions.FlyteTypeException( + _six.text_type, + _datetime.timedelta, + additional_msg="Could not convert string to" " time delta: {}".format(string_value), + ) return cls.from_python_std(_datetime.timedelta(seconds=td)) @classmethod @@ -477,7 +484,7 @@ def short_class_string(cls): """ :rtype: Text """ - return 'Timedelta' + return "Timedelta" def __init__(self, value): """ @@ -499,7 +506,6 @@ def short_string(self): class Generic(_base_sdk_types.FlyteSdkValue): - @classmethod def from_string(cls, string_value): """ @@ -508,11 +514,8 @@ def from_string(cls, string_value): """ try: t = _json_format.Parse(string_value, _struct.Struct()) - except: - raise _user_exceptions.FlyteValueException( - string_value, - "Could not be parsed from JSON." - ) + except Exception: + raise _user_exceptions.FlyteValueException(string_value, "Could not be parsed from JSON.") return cls(t) @classmethod @@ -537,11 +540,8 @@ def from_python_std(cls, t_value): try: t = _json.dumps(t_value) - except: - raise _user_exceptions.FlyteValueException( - t_value, - "Is not JSON serializable." - ) + except Exception: + raise _user_exceptions.FlyteValueException(t_value, "Is not JSON serializable.") return cls(_json_format.Parse(t, _struct.Struct())) @@ -566,7 +566,7 @@ def short_class_string(cls): """ :rtype: Text """ - return 'Generic' + return "Generic" def __init__(self, value): """ diff --git a/flytekit/common/types/proto.py b/flytekit/common/types/proto.py index fbae7d7086..e7a0562c0b 100644 --- a/flytekit/common/types/proto.py +++ b/flytekit/common/types/proto.py @@ -1,12 +1,14 @@ from __future__ import absolute_import -from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import base_sdk_types as _base_sdk_types -from flytekit.models import types as _idl_types, literals as _literals -from google.protobuf import reflection as _proto_reflection - import base64 as _base64 + import six as _six +from google.protobuf import reflection as _proto_reflection + +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.types import base_sdk_types as _base_sdk_types +from flytekit.models import literals as _literals +from flytekit.models import types as _idl_types def create_protobuf(pb_type): @@ -18,16 +20,16 @@ def create_protobuf(pb_type): raise _user_exceptions.FlyteTypeException( expected_type=_proto_reflection.GeneratedProtocolMessageType, received_type=type(pb_type), - received_value=pb_type + received_value=pb_type, ) class _Protobuf(Protobuf): _pb_type = pb_type + return _Protobuf class ProtobufType(_base_sdk_types.FlyteSdkType): - @property def pb_type(cls): """ @@ -62,10 +64,7 @@ def __init__(self, pb_object): data = pb_object.SerializeToString() super(Protobuf, self).__init__( scalar=_literals.Scalar( - binary=_literals.Binary( - value=bytes(data) if _six.PY2 else data, - tag=type(self).tag - ) + 
binary=_literals.Binary(value=bytes(data) if _six.PY2 else data, tag=type(self).tag) ) ) @@ -103,23 +102,14 @@ def from_python_std(cls, t_value): elif isinstance(t_value, cls.pb_type): return cls(t_value) else: - raise _user_exceptions.FlyteTypeException( - type(t_value), - cls.pb_type, - received_value=t_value - ) + raise _user_exceptions.FlyteTypeException(type(t_value), cls.pb_type, received_value=t_value) @classmethod def to_flyte_literal_type(cls): """ :rtype: flytekit.models.types.LiteralType """ - return _idl_types.LiteralType( - simple=_idl_types.SimpleType.BINARY, - metadata={ - cls.PB_FIELD_KEY: cls.descriptor - } - ) + return _idl_types.LiteralType(simple=_idl_types.SimpleType.BINARY, metadata={cls.PB_FIELD_KEY: cls.descriptor},) @classmethod def promote_from_model(cls, literal_model): @@ -133,7 +123,7 @@ def promote_from_model(cls, literal_model): literal_model.scalar.binary.tag, cls.pb_type, received_value=_base64.b64encode(literal_model.scalar.binary.value), - additional_msg="Can not deserialize as proto tags don't match." + additional_msg="Can not deserialize as proto tags don't match.", ) pb_obj = cls.pb_type() pb_obj.ParseFromString(literal_model.scalar.binary.value) diff --git a/flytekit/common/types/schema.py b/flytekit/common/types/schema.py index 14f24e08ca..0c16f71869 100644 --- a/flytekit/common/types/schema.py +++ b/flytekit/common/types/schema.py @@ -1,21 +1,21 @@ from __future__ import absolute_import +import six as _six + from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import base_sdk_types as _base_sdk_types from flytekit.common.types.impl import schema as _schema_impl -from flytekit.models import types as _idl_types, literals as _literals - -import six as _six +from flytekit.models import literals as _literals +from flytekit.models import types as _idl_types class SchemaInstantiator(_base_sdk_types.InstantiableType): - def create_at_known_location(cls, location): """ :param Text location: :rtype: flytekit.common.types.impl.schema.Schema """ - return _schema_impl.Schema.create_at_known_location(location, mode='wb', schema_type=cls.schema_type) + return _schema_impl.Schema.create_at_known_location(location, mode="wb", schema_type=cls.schema_type) def fetch(cls, remote_path, local_path=None): """ @@ -24,20 +24,16 @@ def fetch(cls, remote_path, local_path=None): this location is NOT managed and the schema will not be cleaned up upon exit. :rtype: flytekit.common.types.impl.schema.Schema """ - return _schema_impl.Schema.fetch(remote_path, mode='rb', local_path=local_path, schema_type=cls.schema_type) + return _schema_impl.Schema.fetch(remote_path, mode="rb", local_path=local_path, schema_type=cls.schema_type) def create(cls): """ :rtype: flytekit.common.types.impl.schema.Schema """ - return _schema_impl.Schema.create_at_any_location(mode='wb', schema_type=cls.schema_type) + return _schema_impl.Schema.create_at_any_location(mode="wb", schema_type=cls.schema_type) def create_from_hive_query( - cls, - select_query, - stage_query=None, - schema_to_table_name_map=None, - known_location=None + cls, select_query, stage_query=None, schema_to_table_name_map=None, known_location=None, ): """ Returns a query that can be submitted to Hive and produce the desired output. 
It also returns a properly-typed @@ -56,7 +52,7 @@ def create_from_hive_query( stage_query=stage_query, schema_to_table_name_map=schema_to_table_name_map, known_location=known_location, - schema_type=cls.schema_type + schema_type=cls.schema_type, ) def __call__(cls, *args, **kwargs): @@ -66,7 +62,7 @@ def __call__(cls, *args, **kwargs): :rtype: flytekit.common.types.impl.schema.Schema """ if not args and not kwargs: - return _schema_impl.Schema.create_at_any_location(mode='wb', schema_type=cls.schema_type) + return _schema_impl.Schema.create_at_any_location(mode="wb", schema_type=cls.schema_type) else: return super(SchemaInstantiator, cls).__call__(*args, **kwargs) @@ -86,7 +82,6 @@ def columns(cls): class Schema(_six.with_metaclass(SchemaInstantiator, _base_sdk_types.FlyteSdkValue)): - @classmethod def from_string(cls, string_value): """ @@ -159,9 +154,7 @@ def short_string(self): """ :rtype: Text """ - return "{}".format( - self.scalar.schema, - ) + return "{}".format(self.scalar.schema,) def schema_instantiator(columns=None): @@ -173,11 +166,12 @@ def schema_instantiator(columns=None): if columns is not None and len(columns) == 0: raise _user_exceptions.FlyteValueException( columns, - "When specifying a Schema type with a known set of columns, a non-empty list must be provided as " - "inputs") + "When specifying a Schema type with a known set of columns, a non-empty list must be provided as " "inputs", + ) class _Schema(_six.with_metaclass(SchemaInstantiator, Schema)): _schema_type = _schema_impl.SchemaType(columns=columns) + return _Schema @@ -186,6 +180,8 @@ def schema_instantiator_from_proto(schema_type): :param flytekit.models.types.SchemaType schema_type: :rtype: SchemaInstantiator """ + class _Schema(_six.with_metaclass(SchemaInstantiator, Schema)): _schema_type = _schema_impl.SchemaType.promote_from_model(schema_type) + return _Schema diff --git a/flytekit/common/utils.py b/flytekit/common/utils.py index 3d1dfb46ea..92737b4a1d 100644 --- a/flytekit/common/utils.py +++ b/flytekit/common/utils.py @@ -3,19 +3,17 @@ import logging as _logging import os as _os import shutil as _shutil -from hashlib import sha224 as _sha224 import tempfile as _tempfile import time as _time +from hashlib import sha224 as _sha224 +from pathlib import Path import flytekit as _flytekit from flytekit.configuration import sdk as _sdk_config from flytekit.models.core import identifier as _identifier -from pathlib import Path - -def _dnsify(value): - # type: (Text) -> Text +def _dnsify(value): # type: (str) -> str """ Converts value into a DNS-compliant (RFC1035/RFC1123 DNS_LABEL). The resulting string must only consist of alphanumeric (lower-case a-z, and 0-9) and not exceed 63 characters. It's permitted to have '-' character as long @@ -28,10 +26,10 @@ def _dnsify(value): MAX = 63 HASH_LEN = 10 if len(value) >= MAX: - h = _sha224(value.encode('utf-8')).hexdigest()[:HASH_LEN] - value = "{}-{}".format(h, value[-(MAX - HASH_LEN - 1):]) + h = _sha224(value.encode("utf-8")).hexdigest()[:HASH_LEN] + value = "{}-{}".format(h, value[-(MAX - HASH_LEN - 1) :]) for ch in value: - if ch == '_' or ch == '-' or ch == '.': + if ch == "_" or ch == "-" or ch == ".": # Convert '_' to '-' unless it's the first character, in which case we drop it. if res != "" and len(res) < 62: res += "-" @@ -43,18 +41,18 @@ def _dnsify(value): res += ch else: # Character is upper-case. Add a '-' before it for better readability. 
- if res != "" and res[-1] != '-' and len(res) < 62: + if res != "" and res[-1] != "-" and len(res) < 62: res += "-" res += ch.lower() - if res[-1] == '-': - res = res[:len(res) - 1] + if res[-1] == "-": + res = res[: len(res) - 1] return res def load_proto_from_file(pb2_type, path): - with open(path, 'rb') as reader: + with open(path, "rb") as reader: out = pb2_type() out.ParseFromString(reader.read()) return out @@ -62,7 +60,7 @@ def load_proto_from_file(pb2_type, path): def write_proto_to_file(proto, path): Path(_os.path.dirname(path)).mkdir(parents=True, exist_ok=True) - with open(path, 'wb') as writer: + with open(path, "wb") as writer: writer.write(proto.SerializeToString()) @@ -71,7 +69,6 @@ def get_version_message(): class Directory(object): - def __init__(self, path): """ :param Text path: local path of directory @@ -111,7 +108,7 @@ def __init__(self, working_dir_prefix=None, tmp_dir=None, cleanup=True): :param bool cleanup: Whether the directory should be cleaned up upon exit """ self._tmp_dir = tmp_dir - self._working_dir_prefix = (working_dir_prefix + "_") if working_dir_prefix else '' + self._working_dir_prefix = (working_dir_prefix + "_") if working_dir_prefix else "" self._cleanup = cleanup super(AutoDeletingTempDir, self).__init__(None) @@ -142,7 +139,6 @@ def __str__(self): class PerformanceTimer(object): - def __init__(self, context_statement): """ :param Text context_statement: the statement to log @@ -159,15 +155,16 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): end_wall_time = _time.perf_counter() end_process_time = _time.process_time() - _logging.info("Exiting timed context: {} [Wall Time: {}s, Process Time: {}s]".format( - self._context_statement, - end_wall_time - self._start_wall_time, - end_process_time - self._start_process_time - )) + _logging.info( + "Exiting timed context: {} [Wall Time: {}s, Process Time: {}s]".format( + self._context_statement, + end_wall_time - self._start_wall_time, + end_process_time - self._start_process_time, + ) + ) class ExitStack(object): - def __init__(self, entered_stack=None): self._contexts = entered_stack diff --git a/flytekit/common/workflow.py b/flytekit/common/workflow.py index aca0a1486f..522baeb387 100644 --- a/flytekit/common/workflow.py +++ b/flytekit/common/workflow.py @@ -1,28 +1,35 @@ from __future__ import absolute_import +import datetime as _datetime import uuid as _uuid import six as _six from six.moves import queue as _queue -import datetime as _datetime -from flytekit.common import interface as _interface, nodes as _nodes, sdk_bases as _sdk_bases, \ - launch_plan as _launch_plan, promise as _promise +from flytekit.common import constants as _constants +from flytekit.common import interface as _interface +from flytekit.common import launch_plan as _launch_plan +from flytekit.common import nodes as _nodes +from flytekit.common import promise as _promise +from flytekit.common import sdk_bases as _sdk_bases from flytekit.common.core import identifier as _identifier -from flytekit.common.exceptions import scopes as _exception_scopes, user as _user_exceptions -from flytekit.common.mixins import registerable as _registerable, hash as _hash_mixin +from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.exceptions import system as _system_exceptions +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.mixins import hash as _hash_mixin +from flytekit.common.mixins import registerable as _registerable from flytekit.common.types import 
helpers as _type_helpers from flytekit.configuration import internal as _internal_config from flytekit.engines import loader as _engine_loader -from flytekit.models import interface as _interface_models, literals as _literal_models, common as _common_models -from flytekit.models.core import workflow as _workflow_models, identifier as _identifier_model -from flytekit.common.exceptions import system as _system_exceptions -from flytekit.common import constants as _constants +from flytekit.models import common as _common_models +from flytekit.models import interface as _interface_models +from flytekit.models import literals as _literal_models from flytekit.models.admin import workflow as _admin_workflow_model +from flytekit.models.core import identifier as _identifier_model +from flytekit.models.core import workflow as _workflow_models class Output(object): - def __init__(self, name, value, sdk_type=None, help=None): """ :param Text name: @@ -40,7 +47,7 @@ def __init__(self, name, value, sdk_type=None, help=None): sdk_type = _type_helpers.python_std_to_sdk_type(sdk_type) self._binding_data = _interface.BindingData.from_python_std(sdk_type.to_flyte_literal_type(), value) - self._var = _interface_models.Variable(sdk_type.to_flyte_literal_type(), help or '') + self._var = _interface_models.Variable(sdk_type.to_flyte_literal_type(), help or "") self._name = name def rename_and_return_reference(self, new_name): @@ -50,8 +57,10 @@ def rename_and_return_reference(self, new_name): @staticmethod def _infer_type(value): # TODO: Infer types - raise NotImplementedError("Currently the SDK cannot infer a workflow output type, so please use the type kwarg " - "when instantiating an output.") + raise NotImplementedError( + "Currently the SDK cannot infer a workflow output type, so please use the type kwarg " + "when instantiating an output." 
+ ) @property def name(self): @@ -83,8 +92,17 @@ class SdkWorkflow( _registerable.RegisterableEntity, ) ): - - def __init__(self, inputs, outputs, nodes, id=None, metadata=None, metadata_defaults=None, interface=None, output_bindings=None): + def __init__( + self, + inputs, + outputs, + nodes, + id=None, + metadata=None, + metadata_defaults=None, + interface=None, + output_bindings=None, + ): """ :param list[flytekit.common.promise.Input] inputs: :param list[Output] outputs: @@ -111,22 +129,30 @@ def __init__(self, inputs, outputs, nodes, id=None, metadata=None, metadata_defa ) # Allow overrides if specified for all the arguments to the parent class constructor - id = id if id is not None else _identifier.Identifier( - _identifier_model.ResourceType.WORKFLOW, - _internal_config.PROJECT.get(), - _internal_config.DOMAIN.get(), - _uuid.uuid4().hex, - _internal_config.VERSION.get() + id = ( + id + if id is not None + else _identifier.Identifier( + _identifier_model.ResourceType.WORKFLOW, + _internal_config.PROJECT.get(), + _internal_config.DOMAIN.get(), + _uuid.uuid4().hex, + _internal_config.VERSION.get(), + ) ) metadata = metadata if metadata is not None else _workflow_models.WorkflowMetadata() - interface = interface if interface is not None else _interface.TypedInterface( - {v.name: v.var for v in inputs}, - {v.name: v.var for v in outputs} + interface = ( + interface + if interface is not None + else _interface.TypedInterface({v.name: v.var for v in inputs}, {v.name: v.var for v in outputs}) ) - output_bindings = output_bindings if output_bindings is not None else \ - [_literal_models.Binding(v.name, v.binding_data) for v in outputs] + output_bindings = ( + output_bindings + if output_bindings is not None + else [_literal_models.Binding(v.name, v.binding_data) for v in outputs] + ) super(SdkWorkflow, self).__init__( id=id, @@ -185,13 +211,14 @@ def get_sub_workflows(self): result = [] for n in self.nodes: if n.workflow_node is not None and n.workflow_node.sub_workflow_ref is not None: - if n.executable_sdk_object is not None and n.executable_sdk_object.entity_type_text == 'Workflow': + if n.executable_sdk_object is not None and n.executable_sdk_object.entity_type_text == "Workflow": result.append(n.executable_sdk_object) result.extend(n.executable_sdk_object.get_sub_workflows()) else: raise _system_exceptions.FlyteSystemException( "workflow node with subworkflow found but bad executable " - "object {}".format(n.executable_sdk_object)) + "object {}".format(n.executable_sdk_object) + ) # Ignore other node types (branch, task) return result @@ -241,8 +268,9 @@ def promote_from_model(cls, base_model, sub_workflows=None, tasks=None): base_model_non_system_nodes = cls.get_non_system_nodes(base_model.nodes) sub_workflows = sub_workflows or {} tasks = tasks or {} - node_map = {n.id: _nodes.SdkNode.promote_from_model(n, sub_workflows, tasks) - for n in base_model_non_system_nodes} + node_map = { + n.id: _nodes.SdkNode.promote_from_model(n, sub_workflows, tasks) for n in base_model_non_system_nodes + } # Set upstream nodes for each node for n in base_model_non_system_nodes: @@ -253,7 +281,9 @@ def promote_from_model(cls, base_model, sub_workflows=None, tasks=None): # No inputs/outputs specified, see the constructor for more information on the overrides. 
return cls( - inputs=None, outputs=None, nodes=list(node_map.values()), + inputs=None, + outputs=None, + nodes=list(node_map.values()), id=_identifier.Identifier.promote_from_model(base_model.id), metadata=base_model.metadata, metadata_defaults=base_model.metadata_defaults, @@ -270,19 +300,13 @@ def register(self, project, domain, name, version): :param Text version: """ self.validate() - id_to_register = _identifier.Identifier( - _identifier_model.ResourceType.WORKFLOW, - project, - domain, - name, - version - ) + id_to_register = _identifier.Identifier(_identifier_model.ResourceType.WORKFLOW, project, domain, name, version) old_id = self.id try: self._id = id_to_register _engine_loader.get_engine().get_workflow(self).register(id_to_register) return _six.text_type(self.id) - except: + except Exception: self._id = old_id raise @@ -295,10 +319,7 @@ def serialize(self): :rtype: flyteidl.admin.workflow_pb2.WorkflowSpec """ sub_workflows = self.get_sub_workflows() - return _admin_workflow_model.WorkflowSpec( - self, - sub_workflows, - ).to_flyte_idl() + return _admin_workflow_model.WorkflowSpec(self, sub_workflows,).to_flyte_idl() @_exception_scopes.system_entry_point def validate(self): @@ -306,18 +327,18 @@ def validate(self): @_exception_scopes.system_entry_point def create_launch_plan( - self, - default_inputs=None, - fixed_inputs=None, - schedule=None, - role=None, - notifications=None, - labels=None, - annotations=None, - assumable_iam_role=None, - kubernetes_service_account=None, - raw_output_data_prefix=None, - cls=None + self, + default_inputs=None, + fixed_inputs=None, + schedule=None, + role=None, + notifications=None, + labels=None, + annotations=None, + assumable_iam_role=None, + kubernetes_service_account=None, + raw_output_data_prefix=None, + cls=None, ): """ This method will create a launch plan object that can execute this workflow. 
@@ -346,16 +367,16 @@ class provided should be a subclass of flytekit.common.launch_plan.SdkLaunchPlan if role: assumable_iam_role = role - auth_role = _common_models.AuthRole(assumable_iam_role=assumable_iam_role, - kubernetes_service_account=kubernetes_service_account) + auth_role = _common_models.AuthRole( + assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, + ) raw_output_config = _common_models.RawOutputDataConfig(raw_output_data_prefix or "") return (cls or _launch_plan.SdkRunnableLaunchPlan)( sdk_workflow=self, default_inputs={ - k: user_input.rename_and_return_reference(k) - for k, user_input in _six.iteritems(merged_default_inputs) + k: user_input.rename_and_return_reference(k) for k, user_input in _six.iteritems(merged_default_inputs) }, fixed_inputs=fixed_inputs, schedule=schedule, @@ -375,21 +396,19 @@ def __call__(self, *args, **input_map): ) # Take the default values from the Inputs - compiled_inputs = { - v.name: v.sdk_default - for v in self.user_inputs if not v.sdk_required - } + compiled_inputs = {v.name: v.sdk_default for v in self.user_inputs if not v.sdk_required} compiled_inputs.update(input_map) bindings, upstream_nodes = self.interface.create_bindings_for_inputs(compiled_inputs) node = _nodes.SdkNode( id=None, - metadata=_workflow_models.NodeMetadata("placeholder", _datetime.timedelta(), - _literal_models.RetryStrategy(0)), + metadata=_workflow_models.NodeMetadata( + "placeholder", _datetime.timedelta(), _literal_models.RetryStrategy(0) + ), upstream_nodes=upstream_nodes, bindings=sorted(bindings, key=lambda b: b.var), - sdk_workflow=self + sdk_workflow=self, ) return node @@ -434,30 +453,25 @@ def _discover_workflow_components(workflow_class): elif isinstance(current_obj, _promise.Input): if attribute_name is None or attribute_name not in top_level_attributes: raise _user_exceptions.FlyteValueException( - attribute_name, - "Detected workflow input specified outside of top level." + attribute_name, "Detected workflow input specified outside of top level.", ) inputs.append(current_obj.rename_and_return_reference(attribute_name)) elif isinstance(current_obj, Output): if attribute_name is None or attribute_name not in top_level_attributes: raise _user_exceptions.FlyteValueException( - attribute_name, - "Detected workflow output specified outside of top level." + attribute_name, "Detected workflow output specified outside of top level.", ) outputs.append(current_obj.rename_and_return_reference(attribute_name)) elif isinstance(current_obj, list) or isinstance(current_obj, set) or isinstance(current_obj, tuple): for idx, value in enumerate(current_obj): - to_visit_objs.put( - (_assign_indexed_attribute_name(attribute_name, idx), value)) + to_visit_objs.put((_assign_indexed_attribute_name(attribute_name, idx), value)) elif isinstance(current_obj, dict): # Visit dictionary keys. for key in current_obj.keys(): - to_visit_objs.put( - (_assign_indexed_attribute_name(attribute_name, key), key)) + to_visit_objs.put((_assign_indexed_attribute_name(attribute_name, key), key)) # Visit dictionary values. 
for key, value in _six.iteritems(current_obj): - to_visit_objs.put( - (_assign_indexed_attribute_name(attribute_name, key), value)) + to_visit_objs.put((_assign_indexed_attribute_name(attribute_name, key), value)) return inputs, outputs, nodes @@ -475,5 +489,5 @@ def build_sdk_workflow_from_metaclass(metaclass, on_failure=None, cls=None): inputs=[i for i in sorted(inputs, key=lambda x: x.name)], outputs=[o for o in sorted(outputs, key=lambda x: x.name)], nodes=[n for n in sorted(nodes, key=lambda x: x.id)], - metadata=metadata + metadata=metadata, ) diff --git a/flytekit/common/workflow_execution.py b/flytekit/common/workflow_execution.py index 09298fae12..1af0020d05 100644 --- a/flytekit/common/workflow_execution.py +++ b/flytekit/common/workflow_execution.py @@ -1,7 +1,9 @@ from __future__ import absolute_import import six as _six -from flytekit.common import sdk_bases as _sdk_bases, nodes as _nodes + +from flytekit.common import nodes as _nodes +from flytekit.common import sdk_bases as _sdk_bases from flytekit.common.core import identifier as _core_identifier from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.mixins import artifact as _artifact @@ -12,11 +14,7 @@ class SdkWorkflowExecution( - _six.with_metaclass( - _sdk_bases.ExtendedSdkType, - _execution_models.Execution, - _artifact.ExecutionArtifact - ) + _six.with_metaclass(_sdk_bases.ExtendedSdkType, _execution_models.Execution, _artifact.ExecutionArtifact,) ): def __init__(self, *args, **kwargs): super(SdkWorkflowExecution, self).__init__(*args, **kwargs) @@ -51,8 +49,9 @@ def outputs(self): :rtype: dict[Text, T] or None """ if not self.is_complete: - raise _user_exceptions.FlyteAssertion("Please what until the node execution has completed before " - "requesting the outputs.") + raise _user_exceptions.FlyteAssertion( + "Please wait until the node execution has completed before " "requesting the outputs." + ) if self.error: raise _user_exceptions.FlyteAssertion("Outputs could not be found because the execution ended in failure.") @@ -70,8 +69,9 @@ def error(self): :rtype: flytekit.models.core.execution.ExecutionError or None """ if not self.is_complete: - raise _user_exceptions.FlyteAssertion("Please wait until a workflow has completed before checking for an " - "error.") + raise _user_exceptions.FlyteAssertion( + "Please wait until a workflow has completed before checking for an " "error."
+ ) return self.closure.error @property @@ -109,11 +109,7 @@ def fetch(cls, project, domain, name): """ return cls.promote_from_model( _engine_loader.get_engine().fetch_workflow_execution( - _core_identifier.WorkflowExecutionIdentifier( - project=project, - domain=domain, - name=name - ) + _core_identifier.WorkflowExecutionIdentifier(project=project, domain=domain, name=name) ) ) @@ -140,10 +136,7 @@ def get_node_executions(self, filters=None): :rtype: dict[Text, flytekit.common.nodes.SdkNodeExecution] """ models = _engine_loader.get_engine().get_workflow_execution(self).get_node_executions(filters=filters) - return { - k: _nodes.SdkNodeExecution.promote_from_model(v) - for k, v in _six.iteritems(models) - } + return {k: _nodes.SdkNodeExecution.promote_from_model(v) for k, v in _six.iteritems(models)} def terminate(self, cause): """ diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 2b9b747e82..2df8c7d4f4 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -1,12 +1,14 @@ from __future__ import absolute_import + import logging as _logging import os as _os + import six as _six try: import pathlib as _pathlib except ImportError: - import pathlib2 as _pathlib # python 2 backport + import pathlib2 as _pathlib # python 2 backport def set_flyte_config_file(config_file_path): @@ -15,6 +17,7 @@ def set_flyte_config_file(config_file_path): """ import flytekit.configuration.common as _common import flytekit.configuration.internal as _internal + if config_file_path is not None: config_file_path = _os.path.abspath(config_file_path) if not _pathlib.Path(config_file_path).is_file(): @@ -27,15 +30,14 @@ def set_flyte_config_file(config_file_path): class TemporaryConfiguration(object): - def __init__(self, new_config_path, internal_overrides=None): """ :param Text new_config_path: """ import flytekit.configuration.common as _common + self._internal_overrides = { - _common.format_section_key('internal', k): v - for k, v in _six.iteritems(internal_overrides or {}) + _common.format_section_key("internal", k): v for k, v in _six.iteritems(internal_overrides or {}) } self._new_config_path = new_config_path self._old_config_path = None @@ -44,10 +46,7 @@ def __init__(self, new_config_path, internal_overrides=None): def __enter__(self): import flytekit.configuration.internal as _internal - self._old_internals = { - k: _os.environ.get(k) - for k in _six.iterkeys(self._internal_overrides) - } + self._old_internals = {k: _os.environ.get(k) for k in _six.iterkeys(self._internal_overrides)} self._old_config_path = _os.environ.get(_internal.CONFIGURATION_PATH.env_var) _os.environ.update(self._internal_overrides) set_flyte_config_file(self._new_config_path) diff --git a/flytekit/configuration/auth.py b/flytekit/configuration/auth.py index 407fe88165..a7d1d9917c 100644 --- a/flytekit/configuration/auth.py +++ b/flytekit/configuration/auth.py @@ -2,19 +2,20 @@ from flytekit.configuration import common as _config_common -ASSUMABLE_IAM_ROLE = _config_common.FlyteStringConfigurationEntry('auth', 'assumable_iam_role', default=None) +ASSUMABLE_IAM_ROLE = _config_common.FlyteStringConfigurationEntry("auth", "assumable_iam_role", default=None) """ This is the role the SDK will use by default to execute workflows. For example, in AWS this should be an IAM role string. 
""" KUBERNETES_SERVICE_ACCOUNT = _config_common.FlyteStringConfigurationEntry( - 'auth', 'kubernetes_service_account', default=None) + "auth", "kubernetes_service_account", default=None +) """ This is the kubernetes service account that will be passed to workflow executions. """ -RAW_OUTPUT_DATA_PREFIX = _config_common.FlyteStringConfigurationEntry('auth', 'raw_output_data_prefix', default='') +RAW_OUTPUT_DATA_PREFIX = _config_common.FlyteStringConfigurationEntry("auth", "raw_output_data_prefix", default="") """ This is not output metadata but rather where users can specify an S3 or gcs path for offloaded data like blobs and schemas. diff --git a/flytekit/configuration/aws.py b/flytekit/configuration/aws.py index cd6a9b86d8..2dbb51c7aa 100644 --- a/flytekit/configuration/aws.py +++ b/flytekit/configuration/aws.py @@ -2,18 +2,18 @@ from flytekit.configuration import common as _config_common -S3_SHARD_FORMATTER = _config_common.FlyteRequiredStringConfigurationEntry('aws', 's3_shard_formatter') +S3_SHARD_FORMATTER = _config_common.FlyteRequiredStringConfigurationEntry("aws", "s3_shard_formatter") -S3_SHARD_STRING_LENGTH = _config_common.FlyteIntegerConfigurationEntry('aws', 's3_shard_string_length', default=2) +S3_SHARD_STRING_LENGTH = _config_common.FlyteIntegerConfigurationEntry("aws", "s3_shard_string_length", default=2) -S3_ENDPOINT = _config_common.FlyteStringConfigurationEntry('aws', 'endpoint', default=None) +S3_ENDPOINT = _config_common.FlyteStringConfigurationEntry("aws", "endpoint", default=None) -S3_ACCESS_KEY_ID = _config_common.FlyteStringConfigurationEntry('aws', 'access_key_id', default=None) +S3_ACCESS_KEY_ID = _config_common.FlyteStringConfigurationEntry("aws", "access_key_id", default=None) -S3_SECRET_ACCESS_KEY = _config_common.FlyteStringConfigurationEntry('aws', 'secret_access_key', default=None) +S3_SECRET_ACCESS_KEY = _config_common.FlyteStringConfigurationEntry("aws", "secret_access_key", default=None) -S3_ACCESS_KEY_ID_ENV_NAME = 'AWS_ACCESS_KEY_ID' +S3_ACCESS_KEY_ID_ENV_NAME = "AWS_ACCESS_KEY_ID" -S3_SECRET_ACCESS_KEY_ENV_NAME = 'AWS_SECRET_ACCESS_KEY' +S3_SECRET_ACCESS_KEY_ENV_NAME = "AWS_SECRET_ACCESS_KEY" -S3_ENDPOINT_ARG_NAME = '--endpoint-url' +S3_ENDPOINT_ARG_NAME = "--endpoint-url" diff --git a/flytekit/configuration/common.py b/flytekit/configuration/common.py index 54d9ca7ee8..55cab24010 100644 --- a/flytekit/configuration/common.py +++ b/flytekit/configuration/common.py @@ -1,9 +1,9 @@ from __future__ import absolute_import import abc as _abc +import configparser as _configparser import os as _os -import configparser as _configparser import six as _six from flytekit.common.exceptions import user as _user_exceptions @@ -15,11 +15,10 @@ def format_section_key(section, key): :param Text key: :rtype: Text """ - return 'FLYTE_{section}_{key}'.format(section=section.upper(), key=key.upper()) + return "FLYTE_{section}_{key}".format(section=section.upper(), key=key.upper()) class FlyteConfigurationFile(object): - def __init__(self, location=None): """ This singleton is initialized on module load with empty location. 
If pyflyte is called with @@ -35,9 +34,11 @@ def _load_config(self): if self._config is None and self._location: config = _configparser.ConfigParser() config.read(self._location) - if config.has_section('internal'): - raise _user_exceptions.FlyteAssertion("The config file '{}' cannot contain a section for internal " - "only configurations.".format(self._location)) + if config.has_section("internal"): + raise _user_exceptions.FlyteAssertion( + "The config file '{}' cannot contain a section for internal " + "only configurations.".format(self._location) + ) self._config = config def get_string(self, section, key, default=None): @@ -51,7 +52,7 @@ def get_string(self, section, key, default=None): if self._config is not None: try: return self._config.get(section, key, fallback=default) - except: + except Exception: pass return default @@ -66,7 +67,7 @@ def get_int(self, section, key, default=None): if self._config is not None: try: return self._config.getint(section, key, fallback=default) - except: + except Exception: pass return default @@ -81,7 +82,7 @@ def get_bool(self, section, key, default=None): if self._config is not None: try: return self._config.getboolean(section, key, fallback=default) - except: + except Exception: pass return default @@ -89,12 +90,11 @@ def reset_config(self, location): """ :param Text location: """ - self._location = location or _os.environ.get('FLYTE_INTERNAL_CONFIGURATION_PATH') + self._location = location or _os.environ.get("FLYTE_INTERNAL_CONFIGURATION_PATH") self._config = None class _FlyteConfigurationPatcher(object): - def __init__(self, new_value, config): """ :param Text new_value: @@ -127,13 +127,12 @@ def _get_file_contents(location): :rtype: Text """ if _os.path.isfile(location): - with open(location, 'r') as f: - return f.read().replace('\n', '') + with open(location, "r") as f: + return f.read().replace("\n", "") return None class _FlyteConfigurationEntry(_six.with_metaclass(_abc.ABCMeta, object)): - def __init__(self, section, key, default=None, validator=None, fallback=None): self._section = section self._key = key @@ -200,22 +199,14 @@ def get_patcher(self, value): class _FlyteRequiredConfigurationEntry(_FlyteConfigurationEntry): - def __init__(self, section, key, validator=None): - super(_FlyteRequiredConfigurationEntry, self).__init__( - section, - key, - validator=self._validate_not_null - ) + super(_FlyteRequiredConfigurationEntry, self).__init__(section, key, validator=self._validate_not_null) self._extra_validator = validator def _validate_not_null(self, val): if val is None: raise _user_exceptions.FlyteAssertion( - "No configuration set for [{}] {}. This is a required configuration.".format( - self._section, - self._key - ) + "No configuration set for [{}] {}. 
This is a required configuration.".format(self._section, self._key) ) if self._extra_validator: self._extra_validator(val) @@ -247,7 +238,7 @@ def _getter(self): return CONFIGURATION_SINGLETON.get_bool(self._section, self._key, default=self._default) else: # Because bool('False') is True, compare to the same values that ConfigParser uses - if val.lower() in ['false', '0', 'off', 'no']: + if val.lower() in ["false", "0", "off", "no"]: return False return True @@ -259,7 +250,7 @@ def _getter(self): val = CONFIGURATION_SINGLETON.get_string(self._section, self._key) if val is None: return self._default - return val.split(',') + return val.split(",") class FlyteRequiredStringConfigurationEntry(_FlyteRequiredConfigurationEntry): @@ -293,7 +284,7 @@ def _getter(self): val = CONFIGURATION_SINGLETON.get_string(self._section, self._key) if val is None: return self._default - return val.split(',') + return val.split(",") CONFIGURATION_SINGLETON = FlyteConfigurationFile() diff --git a/flytekit/configuration/creds.py b/flytekit/configuration/creds.py index 2bc4bd7e74..06202f9a7b 100644 --- a/flytekit/configuration/creds.py +++ b/flytekit/configuration/creds.py @@ -2,14 +2,15 @@ from flytekit.configuration import common as _config_common -CLIENT_ID = _config_common.FlyteStringConfigurationEntry('credentials', 'client_id', default=None) +CLIENT_ID = _config_common.FlyteStringConfigurationEntry("credentials", "client_id", default=None) """ This is the public identifier for the app which handles authorization for a Flyte deployment. More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. """ -REDIRECT_URI = _config_common.FlyteStringConfigurationEntry('credentials', 'redirect_uri', - default="http://localhost:12345/callback") +REDIRECT_URI = _config_common.FlyteStringConfigurationEntry( + "credentials", "redirect_uri", default="http://localhost:12345/callback" +) """ This is the callback uri registered with the app which handles authorization for a Flyte deployment. Please note the hardcoded port number. Ideally we would not do this, but some IDPs do not allow wildcards for @@ -19,15 +20,16 @@ More details here: https://www.oauth.com/oauth2-servers/redirect-uris/. """ -AUTHORIZATION_METADATA_KEY = _config_common.FlyteStringConfigurationEntry('credentials', 'authorization_metadata_key', - default="authorization") +AUTHORIZATION_METADATA_KEY = _config_common.FlyteStringConfigurationEntry( + "credentials", "authorization_metadata_key", default="authorization" +) """ The authorization metadata key used for passing access tokens in gRPC requests. Traditionally this value is 'authorization' however it is made configurable. """ -CLIENT_CREDENTIALS_SECRET = _config_common.FlyteStringConfigurationEntry('credentials', 'client_secret', default=None) +CLIENT_CREDENTIALS_SECRET = _config_common.FlyteStringConfigurationEntry("credentials", "client_secret", default=None) """ Used for basic auth, which is automatically called during pyflyte. This will allow the Flyte engine to read the password directly from the environment variable. Note that this is less secure! Please only use this if mounting the @@ -35,19 +37,19 @@ """ -CLIENT_CREDENTIALS_SCOPE = _config_common.FlyteStringConfigurationEntry('credentials', 'scope', default=None) +CLIENT_CREDENTIALS_SCOPE = _config_common.FlyteStringConfigurationEntry("credentials", "scope", default=None) """ Used for basic auth, which is automatically called during pyflyte. This is the scope that will be requested. 
Because there is no user explicitly in this auth flow, certain IDPs require a custom scope for basic auth in the configuration of the authorization server. """ -AUTH_MODE = _config_common.FlyteStringConfigurationEntry('credentials', 'auth_mode', default="standard") +AUTH_MODE = _config_common.FlyteStringConfigurationEntry("credentials", "auth_mode", default="standard") """ The auth mode defines the behavior used to request and refresh credentials. The currently supported modes include: - 'standard' This uses the pkce-enhanced authorization code flow by opening a browser window to initiate credentials access. - 'basic' This uses cert-based auth in which the end user enters his/her username and password and public key encryption is used to facilitate authentication. -- None: No auth will be attempted. +- None: No auth will be attempted. """ diff --git a/flytekit/configuration/gcp.py b/flytekit/configuration/gcp.py index e647c7783b..9e5d50cd1a 100644 --- a/flytekit/configuration/gcp.py +++ b/flytekit/configuration/gcp.py @@ -3,6 +3,4 @@ from flytekit.configuration import common as _config_common GCS_PREFIX = _config_common.FlyteRequiredStringConfigurationEntry("gcp", "gcs_prefix") -GSUTIL_PARALLELISM = _config_common.FlyteBoolConfigurationEntry( - "gcp", "gsutil_parallelism", default=False -) +GSUTIL_PARALLELISM = _config_common.FlyteBoolConfigurationEntry("gcp", "gsutil_parallelism", default=False) diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index 49b2f200fe..bd3a238ca7 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -4,39 +4,40 @@ from flytekit.configuration import common as _common_config -IMAGE = _common_config.FlyteRequiredStringConfigurationEntry('internal', 'image') +IMAGE = _common_config.FlyteRequiredStringConfigurationEntry("internal", "image") # This configuration option specifies the path to the file that holds the configuration options. Don't worry, # there will not be cycles because the parsing of the configuration file intentionally will not read and settings # in the [internal] section. # The default, if you want to use it, should be a file called flytekit.config, located in wherever your python # interpreter originates. -CONFIGURATION_PATH = _common_config.FlyteStringConfigurationEntry('internal', 'configuration_path', - default='flytekit.config') +CONFIGURATION_PATH = _common_config.FlyteStringConfigurationEntry( + "internal", "configuration_path", default="flytekit.config" +) # Project, Domain and Version represent the values at registration time. -PROJECT = _common_config.FlyteStringConfigurationEntry('internal', 'project', default="") -DOMAIN = _common_config.FlyteStringConfigurationEntry('internal', 'domain', default="") -NAME = _common_config.FlyteStringConfigurationEntry('internal', 'name', default="") -VERSION = _common_config.FlyteStringConfigurationEntry('internal', 'version', default="") +PROJECT = _common_config.FlyteStringConfigurationEntry("internal", "project", default="") +DOMAIN = _common_config.FlyteStringConfigurationEntry("internal", "domain", default="") +NAME = _common_config.FlyteStringConfigurationEntry("internal", "name", default="") +VERSION = _common_config.FlyteStringConfigurationEntry("internal", "version", default="") # Project, Domain and Version represent the values at registration time. 
-TASK_PROJECT = _common_config.FlyteStringConfigurationEntry('internal', 'task_project', default="") -TASK_DOMAIN = _common_config.FlyteStringConfigurationEntry('internal', 'task_domain', default="") -TASK_NAME = _common_config.FlyteStringConfigurationEntry('internal', 'task_name', default="") -TASK_VERSION = _common_config.FlyteStringConfigurationEntry('internal', 'task_version', default="") +TASK_PROJECT = _common_config.FlyteStringConfigurationEntry("internal", "task_project", default="") +TASK_DOMAIN = _common_config.FlyteStringConfigurationEntry("internal", "task_domain", default="") +TASK_NAME = _common_config.FlyteStringConfigurationEntry("internal", "task_name", default="") +TASK_VERSION = _common_config.FlyteStringConfigurationEntry("internal", "task_version", default="") # Execution project and domain represent the values passed by execution engine at runtime. -EXECUTION_PROJECT = _common_config.FlyteStringConfigurationEntry('internal', 'execution_project', default="") -EXECUTION_DOMAIN = _common_config.FlyteStringConfigurationEntry('internal', 'execution_domain', default="") -EXECUTION_WORKFLOW = _common_config.FlyteStringConfigurationEntry('internal', 'execution_workflow', default="") -EXECUTION_LAUNCHPLAN = _common_config.FlyteStringConfigurationEntry('internal', 'execution_launchplan', default="") -EXECUTION_NAME = _common_config.FlyteStringConfigurationEntry('internal', 'execution_id', default="") +EXECUTION_PROJECT = _common_config.FlyteStringConfigurationEntry("internal", "execution_project", default="") +EXECUTION_DOMAIN = _common_config.FlyteStringConfigurationEntry("internal", "execution_domain", default="") +EXECUTION_WORKFLOW = _common_config.FlyteStringConfigurationEntry("internal", "execution_workflow", default="") +EXECUTION_LAUNCHPLAN = _common_config.FlyteStringConfigurationEntry("internal", "execution_launchplan", default="") +EXECUTION_NAME = _common_config.FlyteStringConfigurationEntry("internal", "execution_id", default="") # This is another layer of logging level, which can be set by propeller, and can override the SDK configuration if # necessary. (See the sdk.py version of this as well.) -LOGGING_LEVEL = _common_config.FlyteIntegerConfigurationEntry('internal', 'logging_level') +LOGGING_LEVEL = _common_config.FlyteIntegerConfigurationEntry("internal", "logging_level") -_IMAGE_VERSION_REGEX = '.*:(.+)' +_IMAGE_VERSION_REGEX = ".*:(.+)" def look_up_version_from_image_tag(tag): @@ -52,11 +53,10 @@ def look_up_version_from_image_tag(tag): :param Text tag: e.g. somedocker.com/myimage:someversion123 :rtype: Text """ - if tag is None or tag == '': - raise Exception('Bad input for image tag {}'.format(tag)) + if tag is None or tag == "": + raise Exception("Bad input for image tag {}".format(tag)) m = re.match(_IMAGE_VERSION_REGEX, tag) if m is not None: return m.group(1) - raise Exception('Could not parse image version from configuration. Did you set it in the' - 'Dockerfile?') + raise Exception("Could not parse image version from configuration. 
Did you set it in the " "Dockerfile?") diff --git a/flytekit/configuration/platform.py b/flytekit/configuration/platform.py index e4e6cf67a6..42a72175f4 100644 --- a/flytekit/configuration/platform.py +++ b/flytekit/configuration/platform.py @@ -1,11 +1,11 @@ from __future__ import absolute_import -from flytekit.configuration import common as _config_common from flytekit.common import constants as _constants +from flytekit.configuration import common as _config_common -URL = _config_common.FlyteRequiredStringConfigurationEntry('platform', 'url') +URL = _config_common.FlyteRequiredStringConfigurationEntry("platform", "url") -HTTP_URL = _config_common.FlyteStringConfigurationEntry('platform', 'http_url', default=None) +HTTP_URL = _config_common.FlyteStringConfigurationEntry("platform", "http_url", default=None) """ If not starting with either http or https, this setting should begin with // as per the urlparse library and https://tools.ietf.org/html/rfc1808.html, otherwise the netloc will not be properly parsed. @@ -14,13 +14,13 @@ Flyte Admin's gRPC and HTTP points are deployed on different ports. """ -INSECURE = _config_common.FlyteBoolConfigurationEntry('platform', 'insecure', default=False) +INSECURE = _config_common.FlyteBoolConfigurationEntry("platform", "insecure", default=False) CLOUD_PROVIDER = _config_common.FlyteStringConfigurationEntry( - 'platform', 'cloud_provider', default=_constants.CloudProvider.AWS + "platform", "cloud_provider", default=_constants.CloudProvider.AWS ) -AUTH = _config_common.FlyteBoolConfigurationEntry('platform', 'auth', default=False) +AUTH = _config_common.FlyteBoolConfigurationEntry("platform", "auth", default=False) """ This config setting should not normally be filled in. Whether or not an admin server requires authentication should be something published by the admin server itself (typically by returning a 401). However, to help with migration, this diff --git a/flytekit/configuration/resources.py b/flytekit/configuration/resources.py index 949b59024c..783a8e672f 100644 --- a/flytekit/configuration/resources.py +++ b/flytekit/configuration/resources.py @@ -2,49 +2,49 @@ from flytekit.configuration import common as _config_common -DEFAULT_CPU_LIMIT = _config_common.FlyteStringConfigurationEntry('resources', 'default_cpu_limit') +DEFAULT_CPU_LIMIT = _config_common.FlyteStringConfigurationEntry("resources", "default_cpu_limit") """ If not specified explicitly when constructing a task, this limit will be applied as the default. Follows Kubernetes CPU request/limit format. """ -DEFAULT_CPU_REQUEST = _config_common.FlyteStringConfigurationEntry('resources', 'default_cpu_request') +DEFAULT_CPU_REQUEST = _config_common.FlyteStringConfigurationEntry("resources", "default_cpu_request") """ If not specified explicitly when constructing a task, this request will be applied as the default. Follows Kubernetes CPU request/limit format. """ -DEFAULT_MEMORY_LIMIT = _config_common.FlyteStringConfigurationEntry('resources', 'default_memory_limit') +DEFAULT_MEMORY_LIMIT = _config_common.FlyteStringConfigurationEntry("resources", "default_memory_limit") """ If not specified explicitly when constructing a task, this limit will be applied as the default. Follows Kubernetes memory request/limit format.
""" -DEFAULT_MEMORY_REQUEST = _config_common.FlyteStringConfigurationEntry('resources', 'default_memory_request') +DEFAULT_MEMORY_REQUEST = _config_common.FlyteStringConfigurationEntry("resources", "default_memory_request") """ If not specified explicitly when constructing a task, this request will be applied as the default. Follows Kubernetes memory request/limit format. """ -DEFAULT_GPU_LIMIT = _config_common.FlyteStringConfigurationEntry('resources', 'default_gpu_limit') +DEFAULT_GPU_LIMIT = _config_common.FlyteStringConfigurationEntry("resources", "default_gpu_limit") """ If not specified explicitly when constructing a task, this limit will be applied as the default. Follows Kubernetes GPU request/limit format. """ -DEFAULT_GPU_REQUEST = _config_common.FlyteStringConfigurationEntry('resources', 'default_gpu_request') +DEFAULT_GPU_REQUEST = _config_common.FlyteStringConfigurationEntry("resources", "default_gpu_request") """ If not specified explicitly when constructing a task, this request will be applied as the default. Follows Kubernetes GPU request/limit format. """ -DEFAULT_STORAGE_LIMIT = _config_common.FlyteStringConfigurationEntry('resources', 'default_storage_limit') +DEFAULT_STORAGE_LIMIT = _config_common.FlyteStringConfigurationEntry("resources", "default_storage_limit") """ If not specified explicitly when constructing a task, this limit will be applied as the default. Follows Kubernetes storage request/limit format. """ -DEFAULT_STORAGE_REQUEST = _config_common.FlyteStringConfigurationEntry('resources', 'default_storage_request') +DEFAULT_STORAGE_REQUEST = _config_common.FlyteStringConfigurationEntry("resources", "default_storage_request") """ If not specified explicitly when constructing a task, this request will be applied as the default. Follows Kubernetes storage request/limit format. diff --git a/flytekit/configuration/sdk.py b/flytekit/configuration/sdk.py index 4caaac5439..14c8d21e65 100644 --- a/flytekit/configuration/sdk.py +++ b/flytekit/configuration/sdk.py @@ -2,58 +2,58 @@ from flytekit.configuration import common as _config_common -WORKFLOW_PACKAGES = _config_common.FlyteStringListConfigurationEntry('sdk', 'workflow_packages', default=[]) +WORKFLOW_PACKAGES = _config_common.FlyteStringListConfigurationEntry("sdk", "workflow_packages", default=[]) """ This is a comma-delimited list of packages that SDK tools will use to discover entities for the purpose of registration and execution of entities. """ -EXECUTION_ENGINE = _config_common.FlyteStringConfigurationEntry('sdk', 'execution_engine', default='flyte') +EXECUTION_ENGINE = _config_common.FlyteStringConfigurationEntry("sdk", "execution_engine", default="flyte") """ This is a comma-delimited list of package strings, in order, for resolving execution behavior. TODO: Explain how this would be used to extend the SDK """ -TYPE_ENGINES = _config_common.FlyteStringListConfigurationEntry('sdk', 'type_engines', default=[]) +TYPE_ENGINES = _config_common.FlyteStringListConfigurationEntry("sdk", "type_engines", default=[]) """ This is a comma-delimited list of package strings, in order, for resolving type behavior. TODO: Explain how this would be used to extend the SDK """ -LOCAL_SANDBOX = _config_common.FlyteStringConfigurationEntry('sdk', 'local_sandbox', default="/tmp/flyte") +LOCAL_SANDBOX = _config_common.FlyteStringConfigurationEntry("sdk", "local_sandbox", default="/tmp/flyte") """ This is the path where SDK will place files during local executions and testing. 
The SDK will not automatically clean up data in these directories. """ -SDK_PYTHON_VENV = _config_common.FlyteStringListConfigurationEntry('sdk', 'python_venv', default=[]) +SDK_PYTHON_VENV = _config_common.FlyteStringListConfigurationEntry("sdk", "python_venv", default=[]) """ This is a list of commands/args which will be prefixed to the entrypoint command by SDK. """ -ROLE = _config_common.FlyteRequiredStringConfigurationEntry('sdk', 'role') +ROLE = _config_common.FlyteRequiredStringConfigurationEntry("sdk", "role") """ This is the role the SDK will use by default to execute workflows. For example, in AWS this should be an IAM role string. """ -NAME_FORMAT = _config_common.FlyteStringConfigurationEntry('sdk', 'name_format', default='{module}.{name}') +NAME_FORMAT = _config_common.FlyteStringConfigurationEntry("sdk", "name_format", default="{module}.{name}") """ This is a Python format string which the SDK will use to generate names for discovered entities. The default is '{module}.{name}' which will result in strings like 'package.module.name'. Any template portion of the string can only include 'module' or 'name'. So '{name}' is valid, but '{key}' is not. """ -TASK_NAME_FORMAT = _config_common.FlyteStringConfigurationEntry('sdk', 'task_name_format', fallback=NAME_FORMAT) +TASK_NAME_FORMAT = _config_common.FlyteStringConfigurationEntry("sdk", "task_name_format", fallback=NAME_FORMAT) """ This is a Python format string which the SDK will use to generate names for tasks. Any template portion of the string can only include 'module' or 'name'. So '{name}' is valid, but '{key}' is not. If not specified, we fall back to the configuration for :py:attr:`flytekit.configuration.sdk.NAME_FORMAT` """ -WORKFLOW_NAME_FORMAT = _config_common.FlyteStringConfigurationEntry('sdk', 'workflow_name_format', fallback=NAME_FORMAT) +WORKFLOW_NAME_FORMAT = _config_common.FlyteStringConfigurationEntry("sdk", "workflow_name_format", fallback=NAME_FORMAT) """ This is a Python format string which the SDK will use to generate names for workflows. Any template portion of the string can only include 'module' or 'name'. So '{name}' is valid, but '{key}' is not. If not specified, @@ -61,7 +61,7 @@ """ LAUNCH_PLAN_NAME_FORMAT = _config_common.FlyteStringConfigurationEntry( - 'sdk', 'launch_plan_name_format', fallback=NAME_FORMAT + "sdk", "launch_plan_name_format", fallback=NAME_FORMAT ) """ This is a Python format string which the SDK will use to generate names for launch plans. Any template portion of the @@ -69,14 +69,14 @@ we fall back to the configuration for :py:attr:`flytekit.configuration.sdk.NAME_FORMAT` """ -LOGGING_LEVEL = _config_common.FlyteIntegerConfigurationEntry('sdk', 'logging_level', default=20) +LOGGING_LEVEL = _config_common.FlyteIntegerConfigurationEntry("sdk", "logging_level", default=20) """ This is the default logging level for the Python logging library and will be set before user code runs. Note that this configuration is special in that it is a runtime setting, not a compile time setting. This is the only runtime option in this file. """ -PARQUET_ENGINE = _config_common.FlyteStringConfigurationEntry('sdk', 'parquet_engine', default='pyarrow') +PARQUET_ENGINE = _config_common.FlyteStringConfigurationEntry("sdk", "parquet_engine", default="pyarrow") """ This is the parquet engine to use when reading data from parquet files. 
""" diff --git a/flytekit/configuration/statsd.py b/flytekit/configuration/statsd.py index 76bb93990b..ab8cb956d6 100644 --- a/flytekit/configuration/statsd.py +++ b/flytekit/configuration/statsd.py @@ -2,5 +2,5 @@ from flytekit.configuration import common as _common_config -HOST = _common_config.FlyteStringConfigurationEntry('statsd', 'host', default='localhost') -PORT = _common_config.FlyteIntegerConfigurationEntry('statsd', 'port', default=8125) +HOST = _common_config.FlyteStringConfigurationEntry("statsd", "host", default="localhost") +PORT = _common_config.FlyteIntegerConfigurationEntry("statsd", "port", default=8125) diff --git a/flytekit/contrib/notebook/helper.py b/flytekit/contrib/notebook/helper.py index 28ce0a8b0f..88d822b336 100644 --- a/flytekit/contrib/notebook/helper.py +++ b/flytekit/contrib/notebook/helper.py @@ -1,8 +1,11 @@ -from flytekit.common.types.helpers import pack_python_std_map_to_literal_map as _packer -from flytekit.contrib.notebook.supported_types import notebook_types_map as _notebook_types_map +import os as _os + import six as _six from pyspark import SparkConf, SparkContext -import os as _os + +from flytekit.common.types.helpers import pack_python_std_map_to_literal_map as _packer +from flytekit.contrib.notebook.supported_types import notebook_types_map as _notebook_types_map + def record_outputs(outputs=None): """ @@ -15,7 +18,10 @@ def record_outputs(outputs=None): t = type(v) if t not in _notebook_types_map: raise ValueError( - "Currently only primitive types {} are supported for recording from notebook".format(_notebook_types_map)) + "Currently only primitive types {} are supported for recording from notebook".format( + _notebook_types_map + ) + ) tm[k] = _notebook_types_map[t] return _packer(outputs, tm).to_flyte_idl() diff --git a/flytekit/contrib/notebook/supported_types.py b/flytekit/contrib/notebook/supported_types.py index 8e3fbea8f0..ac0977197a 100644 --- a/flytekit/contrib/notebook/supported_types.py +++ b/flytekit/contrib/notebook/supported_types.py @@ -1,6 +1,6 @@ -from flytekit.common.types import primitives as _primitives import datetime as _datetime +from flytekit.common.types import primitives as _primitives notebook_types_map = { int: _primitives.Integer, @@ -9,4 +9,4 @@ str: _primitives.String, _datetime.datetime: _primitives.Datetime, _datetime.timedelta: _primitives.Timedelta, -} \ No newline at end of file +} diff --git a/flytekit/contrib/notebook/tasks.py b/flytekit/contrib/notebook/tasks.py index 293bc419a4..e78ae859ed 100644 --- a/flytekit/contrib/notebook/tasks.py +++ b/flytekit/contrib/notebook/tasks.py @@ -1,46 +1,56 @@ -import os as _os -import json as _json -import papermill as _pm -from google.protobuf import text_format as _text_format, json_format as _json_format -import importlib as _importlib import datetime as _datetime +import importlib as _importlib +import inspect as _inspect +import json as _json +import os as _os import sys as _sys + +import papermill as _pm import six as _six +from google.protobuf import json_format as _json_format +from google.protobuf import text_format as _text_format + from flytekit import __version__ from flytekit.bin import entrypoint as _entrypoint -from flytekit.sdk.types import Types as _Types -from flytekit.common.types import helpers as _type_helpers, primitives as _primitives -from flytekit.common import constants as _constants, sdk_bases as _sdk_bases, interface as _interface2 -from flytekit.common.exceptions import scopes as _exception_scopes, user as _user_exceptions -from 
flytekit.common.tasks import sdk_runnable as _sdk_runnable, spark_task as _spark_task,\ - output as _task_output, task as _base_tasks -from flytekit.models import literals as _literal_models, task as _task_models, interface as _interface +from flytekit.common import constants as _constants +from flytekit.common import interface as _interface2 +from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.tasks import output as _task_output +from flytekit.common.tasks import sdk_runnable as _sdk_runnable +from flytekit.common.tasks import spark_task as _spark_task +from flytekit.common.tasks import task as _base_tasks +from flytekit.common.types import helpers as _type_helpers +from flytekit.contrib.notebook.supported_types import notebook_types_map as _notebook_types_map from flytekit.engines import loader as _engine_loader -import inspect as _inspect +from flytekit.models import interface as _interface +from flytekit.models import literals as _literal_models +from flytekit.models import task as _task_models from flytekit.sdk.spark_types import SparkType as _spark_type -from flytekit.contrib.notebook.supported_types import notebook_types_map as _notebook_types_map +from flytekit.sdk.types import Types as _Types + +OUTPUT_NOTEBOOK = "output_notebook" -OUTPUT_NOTEBOOK = 'output_notebook' -def python_notebook( - notebook_path='', - inputs={}, - outputs={}, - cache_version='', - retries=0, - deprecated='', - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - cache=False, - timeout=None, - environment=None, - cls=None, +def python_notebook( + notebook_path="", + inputs={}, + outputs={}, + cache_version="", + retries=0, + deprecated="", + storage_request=None, + cpu_request=None, + gpu_request=None, + memory_request=None, + storage_limit=None, + cpu_limit=None, + gpu_limit=None, + memory_limit=None, + cache=False, + timeout=None, + environment=None, + cls=None, ): """ Decorator to create a Python Notebook Task definition. 
@@ -48,25 +58,26 @@ def python_notebook( :rtype: SdkNotebookTask """ return SdkNotebookTask( - notebook_path=notebook_path, - inputs=inputs, - outputs=outputs, - task_type=_constants.SdkTaskType.PYTHON_TASK, - discovery_version=cache_version, - retries=retries, - deprecated=deprecated, - storage_request=storage_request, - cpu_request=cpu_request, - gpu_request=gpu_request, - memory_request=memory_request, - storage_limit=storage_limit, - cpu_limit=cpu_limit, - gpu_limit=gpu_limit, - memory_limit=memory_limit, - discoverable=cache, - timeout=timeout or _datetime.timedelta(seconds=0), - environment=environment, - custom={}) + notebook_path=notebook_path, + inputs=inputs, + outputs=outputs, + task_type=_constants.SdkTaskType.PYTHON_TASK, + discovery_version=cache_version, + retries=retries, + deprecated=deprecated, + storage_request=storage_request, + cpu_request=cpu_request, + gpu_request=gpu_request, + memory_request=memory_request, + storage_limit=storage_limit, + cpu_limit=cpu_limit, + gpu_limit=gpu_limit, + memory_limit=memory_limit, + discoverable=cache, + timeout=timeout or _datetime.timedelta(seconds=0), + environment=environment, + custom={}, + ) class SdkNotebookTask(_base_tasks.SdkTask): @@ -77,26 +88,26 @@ class SdkNotebookTask(_base_tasks.SdkTask): """ def __init__( - self, - notebook_path, - inputs, - outputs, - task_type, - discovery_version, - retries, - deprecated, - storage_request, - cpu_request, - gpu_request, - memory_request, - storage_limit, - cpu_limit, - gpu_limit, - memory_limit, - discoverable, - timeout, - environment, - custom + self, + notebook_path, + inputs, + outputs, + task_type, + discovery_version, + retries, + deprecated, + storage_request, + cpu_request, + gpu_request, + memory_request, + storage_limit, + cpu_limit, + gpu_limit, + memory_limit, + discoverable, + timeout, + environment, + custom, ): if _os.path.isabs(notebook_path) is False: @@ -112,15 +123,13 @@ def __init__( _task_models.TaskMetadata( discoverable, _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, - __version__, - 'notebook' + _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "notebook", ), timeout, _literal_models.RetryStrategy(retries), False, discovery_version, - deprecated + deprecated, ), _interface2.TypedInterface({}, {}), custom, @@ -133,8 +142,8 @@ def __init__( cpu_limit=cpu_limit, gpu_limit=gpu_limit, memory_limit=memory_limit, - environment=environment - ) + environment=environment, + ), ) # Add Inputs if inputs is not None: @@ -145,8 +154,9 @@ def __init__( outputs(self) # Add a Notebook output as a Blob. - self.interface.outputs.update(output_notebook=_interface.Variable(_Types.Blob.to_flyte_literal_type(), OUTPUT_NOTEBOOK)) - + self.interface.outputs.update( + output_notebook=_interface.Variable(_Types.Blob.to_flyte_literal_type(), OUTPUT_NOTEBOOK) + ) def _validate_inputs(self, inputs): """ @@ -154,7 +164,7 @@ def _validate_inputs(self, inputs): :raises: flytekit.common.exceptions.user.FlyteValidationException """ for k, v in _six.iteritems(inputs): - sdk_type =_type_helpers.get_sdk_type_from_literal_type(v.type) + sdk_type = _type_helpers.get_sdk_type_from_literal_type(v.type) if sdk_type not in _notebook_types_map.values(): raise _user_exceptions.FlyteValidationException( "Input Type '{}' not supported. Only Primitives are supported for notebook.".format(sdk_type) @@ -173,7 +183,8 @@ def _validate_outputs(self, outputs): if k == OUTPUT_NOTEBOOK: raise ValueError( - "{} is a reserved output keyword. 
Please use a different output name.".format(OUTPUT_NOTEBOOK)) + "{} is a reserved output keyword. Please use a different output name.".format(OUTPUT_NOTEBOOK) + ) sdk_type = _type_helpers.get_sdk_type_from_literal_type(v.type) if sdk_type not in _notebook_types_map.values(): @@ -212,11 +223,18 @@ def unit_test(self, **input_map): :returns: Depends on the behavior of the specific task in the unit engine. """ - return _engine_loader.get_engine('unit').get_task(self).execute( - _type_helpers.pack_python_std_map_to_literal_map(input_map, { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) - for k, v in _six.iteritems(self.interface.inputs) - }) + return ( + _engine_loader.get_engine("unit") + .get_task(self) + .execute( + _type_helpers.pack_python_std_map_to_literal_map( + input_map, + { + k: _type_helpers.get_sdk_type_from_literal_type(v.type) + for k, v in _six.iteritems(self.interface.inputs) + }, + ) + ) ) @_exception_scopes.system_entry_point @@ -227,11 +245,18 @@ def local_execute(self, **input_map): :rtype: dict[Text, T] :returns: The output produced by this task in Python standard format. """ - return _engine_loader.get_engine('local').get_task(self).execute( - _type_helpers.pack_python_std_map_to_literal_map(input_map, { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) - for k, v in _six.iteritems(self.interface.inputs) - }) + return ( + _engine_loader.get_engine("local") + .get_task(self) + .execute( + _type_helpers.pack_python_std_map_to_literal_map( + input_map, + { + k: _type_helpers.get_sdk_type_from_literal_type(v.type) + for k, v in _six.iteritems(self.interface.inputs) + }, + ) + ) ) @_exception_scopes.system_entry_point @@ -246,27 +271,24 @@ def execute(self, context, inputs): working directory (with the names provided), which will in turn allow Flyte Propeller to push along the workflow. Where as local engine will merely feed the outputs directly into the next node. """ - inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std(inputs, { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs) - }) + inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std( + inputs, + {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs)}, + ) input_notebook_path = self._notebook_path # Execute Notebook via Papermill. - output_notebook_path = input_notebook_path.split(".ipynb")[0] + '-out.ipynb' - _pm.execute_notebook( - input_notebook_path, - output_notebook_path, - parameters=inputs_dict - ) + output_notebook_path = input_notebook_path.split(".ipynb")[0] + "-out.ipynb" + _pm.execute_notebook(input_notebook_path, output_notebook_path, parameters=inputs_dict) # Parse Outputs from Notebook. outputs = None with open(output_notebook_path) as json_file: data = _json.load(json_file) - for p in data['cells']: - meta = p['metadata'] + for p in data["cells"]: + meta = p["metadata"] if "outputs" in meta["tags"]: - outputs = ' '.join(p['outputs'][0]['data']['text/plain']) + outputs = " ".join(p["outputs"][0]["data"]["text/plain"]) if outputs is not None: dict = _literal_models._literals_pb2.LiteralMap() @@ -274,15 +296,14 @@ def execute(self, context, inputs): # Add output_notebook as an output to the task. 
output_notebook = _task_output.OutputReference( - _type_helpers.get_sdk_type_from_literal_type(_Types.Blob.to_flyte_literal_type())) + _type_helpers.get_sdk_type_from_literal_type(_Types.Blob.to_flyte_literal_type()) + ) output_notebook.set(output_notebook_path) output_literal_map = _literal_models.LiteralMap.from_flyte_idl(dict) output_literal_map.literals[OUTPUT_NOTEBOOK] = output_notebook.sdk_value - return { - _constants.OUTPUT_FILE_NAME: output_literal_map - } + return {_constants.OUTPUT_FILE_NAME: output_literal_map} @property def container(self): @@ -307,21 +328,22 @@ def container(self): "--inputs", "{{.input}}", "--output-prefix", - "{{.outputPrefix}}"] + "{{.outputPrefix}}", + ] return self._container def _get_container_definition( - self, - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - environment=None, - **kwargs + self, + storage_request=None, + cpu_request=None, + gpu_request=None, + memory_request=None, + storage_limit=None, + cpu_limit=None, + gpu_limit=None, + memory_limit=None, + environment=None, + **kwargs ): """ :param Text storage_request: @@ -344,61 +366,29 @@ def _get_container_definition( requests = [] if storage_request: requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.STORAGE, - storage_request - ) + _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_request) ) if cpu_request: - requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.CPU, - cpu_request - ) - ) + requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_request)) if gpu_request: - requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.GPU, - gpu_request - ) - ) + requests.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_request)) if memory_request: requests.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.MEMORY, - memory_request - ) + _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_request) ) limits = [] if storage_limit: limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.STORAGE, - storage_limit - ) + _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.STORAGE, storage_limit) ) if cpu_limit: - limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.CPU, - cpu_limit - ) - ) + limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.CPU, cpu_limit)) if gpu_limit: - limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.GPU, - gpu_limit - ) - ) + limits.append(_task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.GPU, gpu_limit)) if memory_limit: limits.append( - _task_models.Resources.ResourceEntry( - _task_models.Resources.ResourceName.MEMORY, - memory_limit - ) + _task_models.Resources.ResourceEntry(_task_models.Resources.ResourceName.MEMORY, memory_limit) ) return _sdk_runnable.SdkRunnableContainer( @@ -406,45 +396,45 @@ def _get_container_definition( args=[], resources=_task_models.Resources(limits=limits, requests=requests), env=environment, - config={} + config={}, ) def spark_notebook( - notebook_path, - inputs={}, - outputs={}, - spark_conf=None, - cache_version='', - 
retries=0, - deprecated='', - cache=False, - timeout=None, - environment=None, + notebook_path, + inputs={}, + outputs={}, + spark_conf=None, + cache_version="", + retries=0, + deprecated="", + cache=False, + timeout=None, + environment=None, ): """ Decorator to create a Notebook spark task. This task will connect to a Spark cluster, configure the environment, and then execute the code within the notebook_path as the Spark driver program. """ return SdkNotebookSparkTask( - notebook_path=notebook_path, - inputs=inputs, - outputs=outputs, - spark_conf=spark_conf, - discovery_version=cache_version, - retries=retries, - deprecated=deprecated, - discoverable=cache, - timeout=timeout or _datetime.timedelta(seconds=0), - environment=environment or {}, - ) + notebook_path=notebook_path, + inputs=inputs, + outputs=outputs, + spark_conf=spark_conf, + discovery_version=cache_version, + retries=retries, + deprecated=deprecated, + discoverable=cache, + timeout=timeout or _datetime.timedelta(seconds=0), + environment=environment or {}, + ) def _find_instance_module(): frame = _inspect.currentframe() while frame: - if frame.f_code.co_name == '<module>': - return frame.f_globals['__name__'] + if frame.f_code.co_name == "<module>": + return frame.f_globals["__name__"] frame = frame.f_back return None @@ -457,40 +447,40 @@ class SdkNotebookSparkTask(SdkNotebookTask): """ def __init__( - self, - notebook_path, - inputs, - outputs, - spark_conf, - discovery_version, - retries, - deprecated, - discoverable, - timeout, - environment=None, + self, + notebook_path, + inputs, + outputs, + spark_conf, + discovery_version, + retries, + deprecated, + discoverable, + timeout, + environment=None, ): spark_exec_path = _os.path.abspath(_entrypoint.__file__) - if spark_exec_path.endswith('.pyc'): + if spark_exec_path.endswith(".pyc"): spark_exec_path = spark_exec_path[:-1] if spark_conf is None: # Parse spark_conf from notebook if not set at task_level.
with open(notebook_path) as json_file: data = _json.load(json_file) - for p in data['cells']: - meta = p['metadata'] + for p in data["cells"]: + meta = p["metadata"] if "tags" in meta: if "conf" in meta["tags"]: - sc_str = ' '.join(p["source"]) + sc_str = " ".join(p["source"]) ldict = {} - exec (sc_str, globals(), ldict) - spark_conf = ldict['spark_conf'] + exec(sc_str, globals(), ldict) + spark_conf = ldict["spark_conf"] spark_job = _task_models.SparkJob( spark_conf=spark_conf, - main_class= "", - spark_type= _spark_type.PYTHON, + main_class="", + spark_type=_spark_type.PYTHON, hadoop_conf={}, application_file="local://" + spark_exec_path, executor_path=_sys.executable, @@ -518,11 +508,7 @@ def __init__( _json_format.MessageToDict(spark_job), ) - def _get_container_definition( - self, - environment=None, - **kwargs - ): + def _get_container_definition(self, environment=None, **kwargs): """ :rtype: flytekit.models.task.Container """ @@ -532,5 +518,5 @@ def _get_container_definition( args=[], resources=_task_models.Resources(limits=[], requests=[]), env=environment or {}, - config={} - ) \ No newline at end of file + config={}, + ) diff --git a/flytekit/contrib/sensors/base_sensor.py b/flytekit/contrib/sensors/base_sensor.py index e8614333a4..323976b001 100644 --- a/flytekit/contrib/sensors/base_sensor.py +++ b/flytekit/contrib/sensors/base_sensor.py @@ -2,16 +2,15 @@ import abc as _abc import datetime as _datetime +import logging as _logging import sys as _sys import time as _time import traceback as _traceback import six as _six -import logging as _logging class Sensor(_six.with_metaclass(_abc.ABCMeta, object)): - def __init__(self, evaluation_interval=None, max_failures=0): """ :param datetime.timedelta evaluation_interval: This is the time to wait between evaluation attempts of this diff --git a/flytekit/contrib/sensors/impl.py b/flytekit/contrib/sensors/impl.py index 10363e7099..ce601922c1 100644 --- a/flytekit/contrib/sensors/impl.py +++ b/flytekit/contrib/sensors/impl.py @@ -1,18 +1,11 @@ from __future__ import absolute_import -from flytekit.plugins import hmsclient as _hmsclient from flytekit.contrib.sensors.base_sensor import Sensor as _Sensor +from flytekit.plugins import hmsclient as _hmsclient class _HiveSensor(_Sensor): - - def __init__( - self, - host, - port, - schema='default', - **kwargs - ): + def __init__(self, host, port, schema="default", **kwargs): """ :param Text host: :param Text port: @@ -28,14 +21,7 @@ def __init__( class HiveTableSensor(_HiveSensor): - - def __init__( - self, - table_name, - host, - port, - **kwargs - ): + def __init__(self, table_name, host, port, **kwargs): """ :param Text host: The host for the Hive metastore Thrift service. :param Text port: The port for the Hive metastore Thrift Service @@ -43,11 +29,7 @@ def __init__( :param **kwargs: See _HiveSensor and flytekit.contrib.sensors.base_sensor.Sensor for more parameters. """ - super(HiveTableSensor, self).__init__( - host, - port, - **kwargs - ) + super(HiveTableSensor, self).__init__(host, port, **kwargs) self._table_name = table_name def _do_poll(self): @@ -63,15 +45,7 @@ def _do_poll(self): class HiveNamedPartitionSensor(_HiveSensor): - - def __init__( - self, - table_name, - partition_names, - host, - port, - **kwargs - ): + def __init__(self, table_name, partition_names, host, port, **kwargs): """ This class allows sensing for a specific named Hive Partition. This is the preferred partition sensing operator because it is more efficient than evaluating a filter expression. 
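Reviewer note: the behavior of these Hive sensors is unchanged by the reformatting. For context, they are driven through the public `sense()` entry point inherited from `Sensor`, which re-evaluates `_do_poll()` on the configured interval. A minimal usage sketch; the metastore host, port, and table name below are illustrative, not taken from this diff:

```python
from datetime import timedelta

from flytekit.contrib.sensors.impl import HiveTableSensor

# Hypothetical metastore endpoint; 9083 is the conventional Hive Thrift port.
sensor = HiveTableSensor(
    table_name="fact_orders",
    host="hive-metastore.example.com",
    port=9083,
    schema="default",                          # forwarded to _HiveSensor
    evaluation_interval=timedelta(minutes=1),  # forwarded to the Sensor base class
)

# sense() evaluates the poll condition and returns a bool, which SensorTask
# uses to decide whether the task should be retried.
table_exists = sensor.sense()
```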
@@ -101,15 +75,7 @@ def _do_poll(self): class HiveFilteredPartitionSensor(_HiveSensor): - - def __init__( - self, - table_name, - partition_filter, - host, - port, - **kwargs - ): + def __init__(self, table_name, partition_filter, host, port, **kwargs): """ This class allows sensing for any Hive partition that matches a filter expression. It is recommended that the user should use HiveNamedPartitionSensor instead when possible because it is a more efficient API. @@ -122,11 +88,7 @@ def __init__( :param **kwargs: See _HiveSensor and flytekit.contrib.sensors.base_sensor.Sensor for more parameters. """ - super(HiveFilteredPartitionSensor, self).__init__( - host, - port, - **kwargs - ) + super(HiveFilteredPartitionSensor, self).__init__(host, port, **kwargs) self._table_name = table_name self._partition_filter = partition_filter @@ -136,10 +98,7 @@ def _do_poll(self): """ with self._hive_metastore_client as client: partitions = client.get_partitions_by_filter( - db_name=self._schema, - tbl_name=self._table_name, - filter=self._partition_filter, - max_parts=1, + db_name=self._schema, tbl_name=self._table_name, filter=self._partition_filter, max_parts=1, ) if partitions: return True, None diff --git a/flytekit/contrib/sensors/task.py b/flytekit/contrib/sensors/task.py index e20aca48f9..6f427cbb51 100644 --- a/flytekit/contrib/sensors/task.py +++ b/flytekit/contrib/sensors/task.py @@ -1,4 +1,5 @@ from __future__ import absolute_import + from flytekit.common import constants as _common_constants from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.tasks import sdk_runnable as _sdk_runnable @@ -11,8 +12,7 @@ def _execute_user_code(self, context, inputs): if sensor is not None: if not isinstance(sensor, _Sensor): raise _user_exceptions.FlyteTypeException( - received_type=type(sensor), - expected_type=_Sensor, + received_type=type(sensor), expected_type=_Sensor, ) succeeded = sensor.sense() if not succeeded: @@ -23,7 +23,7 @@ def sensor_task( _task_function=None, retries=0, interruptible=None, - deprecated='', + deprecated="", storage_request=None, cpu_request=None, gpu_request=None, @@ -96,6 +96,7 @@ def my_task(wf_params): otherwise mimic the behavior. :rtype: SensorTask """ + def wrapper(fn): return (SensorTask or cls)( task_function=fn, @@ -114,7 +115,7 @@ def wrapper(fn): timeout=timeout, environment=environment, custom={}, - discovery_version='', + discovery_version="", discoverable=False, ) diff --git a/flytekit/engines/common.py b/flytekit/engines/common.py index ee7b3e709a..68c8c6f1d4 100644 --- a/flytekit/engines/common.py +++ b/flytekit/engines/common.py @@ -1,7 +1,9 @@ from __future__ import absolute_import import abc as _abc + import six as _six + from flytekit.models import common as _common_models @@ -10,6 +12,7 @@ class BaseWorkflowExecutor(_six.with_metaclass(_common_models.FlyteABCMeta, obje This class must be implemented for any engine to create, interact with, and execute workflows using the FlyteKit SDK. 
""" + def __init__(self, sdk_workflow): """ :param flytekit.common.workflow.SdkWorkflow sdk_workflow: @@ -88,7 +91,6 @@ def terminate(self, cause): class BaseNodeExecution(_six.with_metaclass(_common_models.FlyteABCMeta, object)): - def __init__(self, node_execution): """ :param flytekit.common.nodes.SdkNodeExecution node_execution: @@ -139,7 +141,6 @@ def sync(self): class BaseTaskExecution(_six.with_metaclass(_common_models.FlyteABCMeta, object)): - def __init__(self, task_exec): """ :param flytekit.common.tasks.executions.SdkTaskExecution task_exec: @@ -184,7 +185,6 @@ def get_child_executions(self, filters=None): class BaseLaunchPlanLauncher(_six.with_metaclass(_common_models.FlyteABCMeta, object)): - def __init__(self, sdk_launch_plan): """ :param flytekit.common.launch_plan.SdkLaunchPlan sdk_launch_plan: @@ -207,8 +207,16 @@ def register(self, identifier): pass @_abc.abstractmethod - def launch(self, project, domain, name, inputs, notification_overrides=None, label_overrides=None, - annotation_overrides=None): + def launch( + self, + project, + domain, + name, + inputs, + notification_overrides=None, + label_overrides=None, + annotation_overrides=None, + ): """ Registers the launch plan and returns the identifier. :param Text project: @@ -262,8 +270,17 @@ def register(self, identifier): pass @_abc.abstractmethod - def launch(self, project, domain, name=None, inputs=None, notification_overrides=None, - label_overrides=None, annotation_overrides=None, auth_role=None): + def launch( + self, + project, + domain, + name=None, + inputs=None, + notification_overrides=None, + label_overrides=None, + annotation_overrides=None, + auth_role=None, + ): """ Executes the task as a single task execution and returns the identifier. :param Text project: diff --git a/flytekit/engines/flyte/engine.py b/flytekit/engines/flyte/engine.py index 49d8c24fab..ba68fd7de5 100644 --- a/flytekit/engines/flyte/engine.py +++ b/flytekit/engines/flyte/engine.py @@ -4,27 +4,34 @@ import os as _os import traceback as _traceback from datetime import datetime as _datetime -from deprecated import deprecated as _deprecated import six as _six +from deprecated import deprecated as _deprecated from flyteidl.core import literals_pb2 as _literals_pb2 from flytekit import __version__ as _api_version from flytekit.clients.friendly import SynchronousFlyteClient as _SynchronousFlyteClient -from flytekit.clients.helpers import iterate_node_executions as _iterate_node_executions, iterate_task_executions as \ - _iterate_task_executions -from flytekit.common import utils as _common_utils, constants as _constants -from flytekit.common.exceptions import user as _user_exceptions, scopes as _exception_scopes -from flytekit.configuration import ( - platform as _platform_config, internal as _internal_config, sdk as _sdk_config, auth as _auth_config, -) +from flytekit.clients.helpers import iterate_node_executions as _iterate_node_executions +from flytekit.clients.helpers import iterate_task_executions as _iterate_task_executions +from flytekit.common import constants as _constants +from flytekit.common import utils as _common_utils +from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.configuration import auth as _auth_config +from flytekit.configuration import internal as _internal_config +from flytekit.configuration import platform as _platform_config +from flytekit.configuration import sdk as _sdk_config from flytekit.engines import common as 
_common_engine from flytekit.interfaces.data import data_proxy as _data_proxy from flytekit.interfaces.stats.taggable import get_stats as _get_stats -from flytekit.models import task as _task_models, execution as _execution_models, \ - literals as _literals, common as _common_models -from flytekit.models.admin import common as _common, workflow as _workflow_model -from flytekit.models.core import errors as _error_models, identifier as _identifier +from flytekit.models import common as _common_models +from flytekit.models import execution as _execution_models +from flytekit.models import literals as _literals +from flytekit.models import task as _task_models +from flytekit.models.admin import common as _common +from flytekit.models.admin import workflow as _workflow_model +from flytekit.models.core import errors as _error_models +from flytekit.models.core import identifier as _identifier class _FlyteClientManager(object): @@ -46,7 +53,6 @@ def client(self): class FlyteEngineFactory(_common_engine.BaseExecutionEngineFactory): - def get_workflow(self, sdk_workflow): """ :param flytekit.common.workflow.SdkWorkflow sdk_workflow: @@ -95,8 +101,7 @@ def fetch_workflow_execution(self, wf_exec_id): :rtype: flytekit.models.execution.Execution """ return _FlyteClientManager( - _platform_config.URL.get(), - insecure=_platform_config.INSECURE.get() + _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() ).client.get_execution(wf_exec_id) def fetch_task(self, task_id): @@ -106,8 +111,7 @@ def fetch_task(self, task_id): :rtype: flytekit.models.task.Task """ return _FlyteClientManager( - _platform_config.URL.get(), - insecure=_platform_config.INSECURE.get() + _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() ).client.get_task(task_id) def fetch_latest_task(self, named_task): @@ -119,9 +123,7 @@ def fetch_latest_task(self, named_task): task_list, _ = _FlyteClientManager( _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() ).client.list_tasks_paginated( - named_task, - limit=1, - sort_by=_common.Sort("created_at", _common.Sort.Direction.DESCENDING), + named_task, limit=1, sort_by=_common.Sort("created_at", _common.Sort.Direction.DESCENDING), ) return task_list[0] if task_list else None @@ -133,18 +135,14 @@ def fetch_launch_plan(self, launch_plan_id): """ if launch_plan_id.version: return _FlyteClientManager( - _platform_config.URL.get(), - insecure=_platform_config.INSECURE.get() + _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() ).client.get_launch_plan(launch_plan_id) else: named_entity_id = _common_models.NamedEntityIdentifier( - launch_plan_id.project, - launch_plan_id.domain, - launch_plan_id.name + launch_plan_id.project, launch_plan_id.domain, launch_plan_id.name ) return _FlyteClientManager( - _platform_config.URL.get(), - insecure=_platform_config.INSECURE.get() + _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() ).client.get_active_launch_plan(named_entity_id) def fetch_workflow(self, workflow_id): @@ -154,33 +152,46 @@ def fetch_workflow(self, workflow_id): :rtype: flytekit.models.admin.workflow.Workflow """ return _FlyteClientManager( - _platform_config.URL.get(), - insecure=_platform_config.INSECURE.get() + _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() ).client.get_workflow(workflow_id) class FlyteLaunchPlan(_common_engine.BaseLaunchPlanLauncher): - def register(self, identifier): client = _FlyteClientManager(_platform_config.URL.get(), 
insecure=_platform_config.INSECURE.get()).client try: - client.create_launch_plan( - identifier, - self.sdk_launch_plan - ) + client.create_launch_plan(identifier, self.sdk_launch_plan) except _user_exceptions.FlyteEntityAlreadyExistsException: pass - @_deprecated(reason="Use launch instead", version='0.9.0') - def execute(self, project, domain, name, inputs, notification_overrides=None, label_overrides=None, - annotation_overrides=None): + @_deprecated(reason="Use launch instead", version="0.9.0") + def execute( + self, + project, + domain, + name, + inputs, + notification_overrides=None, + label_overrides=None, + annotation_overrides=None, + ): """ Deprecated. Use launch instead. """ - return self.launch(project, domain, name, inputs, notification_overrides, label_overrides, annotation_overrides) + return self.launch( + project, domain, name, inputs, notification_overrides, label_overrides, annotation_overrides, + ) - def launch(self, project, domain, name, inputs, notification_overrides=None, label_overrides=None, - annotation_overrides=None): + def launch( + self, + project, + domain, + name, + inputs, + notification_overrides=None, + label_overrides=None, + annotation_overrides=None, + ): """ Creates a workflow execution using parameters specified in the launch plan. :param Text project: @@ -193,13 +204,11 @@ def launch(self, project, domain, name, inputs, notification_overrides=None, lab :param flytekit.models.common.Annotations annotation_overrides: :rtype: flytekit.models.execution.Execution """ - disable_all = (notification_overrides == []) + disable_all = notification_overrides == [] if disable_all: notification_overrides = None else: - notification_overrides = _execution_models.NotificationList( - notification_overrides or [] - ) + notification_overrides = _execution_models.NotificationList(notification_overrides or []) disable_all = None try: @@ -212,8 +221,8 @@ def launch(self, project, domain, name, inputs, notification_overrides=None, lab self.sdk_launch_plan.id, _execution_models.ExecutionMetadata( _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, - 'sdk', # TODO: get principle - 0 # TODO: Detect nesting + "sdk", # TODO: get principle + 0, # TODO: Detect nesting ), notifications=notification_overrides, disable_all=disable_all, @@ -232,39 +241,25 @@ def update(self, identifier, state): :param int state: Enum value from flytekit.models.launch_plan.LaunchPlanState """ return _FlyteClientManager( - _platform_config.URL.get(), - insecure=_platform_config.INSECURE.get() + _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() ).client.update_launch_plan(identifier, state) class FlyteWorkflow(_common_engine.BaseWorkflowExecutor): - def register(self, identifier): client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client try: sub_workflows = self.sdk_workflow.get_sub_workflows() - return client.create_workflow( - identifier, - _workflow_model.WorkflowSpec( - self.sdk_workflow, - sub_workflows, - ) - ) + return client.create_workflow(identifier, _workflow_model.WorkflowSpec(self.sdk_workflow, sub_workflows,),) except _user_exceptions.FlyteEntityAlreadyExistsException: pass class FlyteTask(_common_engine.BaseTaskExecutor): - def register(self, identifier): client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client try: - client.create_task( - identifier, - _task_models.TaskSpec( - self.sdk_task - ) - ) + client.create_task(identifier, 
_task_models.TaskSpec(self.sdk_task)) except _user_exceptions.FlyteEntityAlreadyExistsException: pass @@ -294,7 +289,7 @@ def execute(self, inputs, context=None): execution_id=_identifier.WorkflowExecutionIdentifier( project=_internal_config.EXECUTION_PROJECT.get(), domain=_internal_config.EXECUTION_DOMAIN.get(), - name=_internal_config.EXECUTION_NAME.get() + name=_internal_config.EXECUTION_NAME.get(), ), execution_date=_datetime.utcnow(), stats=_get_stats( @@ -304,29 +299,25 @@ def execute(self, inputs, context=None): "{}.{}.{}.user_stats".format( _internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(), _internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(), - _internal_config.TASK_NAME.get() or _internal_config.NAME.get() + _internal_config.TASK_NAME.get() or _internal_config.NAME.get(), ), tags={ - 'exec_project': _internal_config.EXECUTION_PROJECT.get(), - 'exec_domain': _internal_config.EXECUTION_DOMAIN.get(), - 'exec_workflow': _internal_config.EXECUTION_WORKFLOW.get(), - 'exec_launchplan': _internal_config.EXECUTION_LAUNCHPLAN.get(), - 'api_version': _api_version - } + "exec_project": _internal_config.EXECUTION_PROJECT.get(), + "exec_domain": _internal_config.EXECUTION_DOMAIN.get(), + "exec_workflow": _internal_config.EXECUTION_WORKFLOW.get(), + "exec_launchplan": _internal_config.EXECUTION_LAUNCHPLAN.get(), + "api_version": _api_version, + }, ), logging=_logging, - tmp_dir=task_dir + tmp_dir=task_dir, ), - inputs + inputs, ) except _exception_scopes.FlyteScopedException as e: _logging.error("!!! Begin Error Captured by Flyte !!!") output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument( - _error_models.ContainerError( - e.error_code, - e.verbose_message, - e.kind - ) + _error_models.ContainerError(e.error_code, e.verbose_message, e.kind) ) _logging.error(e.verbose_message) _logging.error("!!! End Error Captured by Flyte !!!") @@ -335,23 +326,29 @@ def execute(self, inputs, context=None): exc_str = _traceback.format_exc() output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument( _error_models.ContainerError( - "SYSTEM:Unknown", - exc_str, - _error_models.ContainerError.Kind.RECOVERABLE + "SYSTEM:Unknown", exc_str, _error_models.ContainerError.Kind.RECOVERABLE, ) ) _logging.error(exc_str) _logging.error("!!! End Error Captured by Flyte !!!") finally: for k, v in _six.iteritems(output_file_dict): - _common_utils.write_proto_to_file( - v.to_flyte_idl(), - _os.path.join(temp_dir.name, k) - ) - _data_proxy.Data.put_data(temp_dir.name, context['output_prefix'], is_multipart=True) + _common_utils.write_proto_to_file(v.to_flyte_idl(), _os.path.join(temp_dir.name, k)) + _data_proxy.Data.put_data( + temp_dir.name, context["output_prefix"], is_multipart=True, + ) - def launch(self, project, domain, name=None, inputs=None, notification_overrides=None, label_overrides=None, - annotation_overrides=None, auth_role=None): + def launch( + self, + project, + domain, + name=None, + inputs=None, + notification_overrides=None, + label_overrides=None, + annotation_overrides=None, + auth_role=None, + ): """ Executes the task as a single task execution and returns the identifier. 
:param Text project: @@ -365,13 +362,11 @@ def launch(self, project, domain, name=None, inputs=None, notification_overrides :param flytekit.models.common.AuthRole auth_role: :rtype: flytekit.models.execution.Execution """ - disable_all = (notification_overrides == []) + disable_all = notification_overrides == [] if disable_all: notification_overrides = None else: - notification_overrides = _execution_models.NotificationList( - notification_overrides or [] - ) + notification_overrides = _execution_models.NotificationList(notification_overrides or []) disable_all = None if not auth_role: @@ -379,11 +374,14 @@ def launch(self, project, domain, name=None, inputs=None, notification_overrides kubernetes_service_account = _auth_config.KUBERNETES_SERVICE_ACCOUNT.get() if not (assumable_iam_role or kubernetes_service_account): - _logging.warning("Using deprecated `role` from config. " - "Please update your config to use `assumable_iam_role` instead") + _logging.warning( + "Using deprecated `role` from config. " + "Please update your config to use `assumable_iam_role` instead" + ) assumable_iam_role = _sdk_config.ROLE.get() - auth_role = _common_models.AuthRole(assumable_iam_role=assumable_iam_role, - kubernetes_service_account=kubernetes_service_account) + auth_role = _common_models.AuthRole( + assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, + ) try: # TODO(katrogan): Add handling to register the underlying task if it's not already. @@ -396,8 +394,8 @@ def launch(self, project, domain, name=None, inputs=None, notification_overrides self.sdk_task.id, _execution_models.ExecutionMetadata( _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, - 'sdk', # TODO: get principle - 0 # TODO: Detect nesting + "sdk", # TODO: get principle + 0, # TODO: Detect nesting ), notifications=notification_overrides, disable_all=disable_all, @@ -413,7 +411,6 @@ def launch(self, project, domain, name=None, inputs=None, notification_overrides class FlyteWorkflowExecution(_common_engine.BaseWorkflowExecution): - def get_node_executions(self, filters=None): """ :param list[flytekit.models.filters.Filter] filters: @@ -421,8 +418,7 @@ def get_node_executions(self, filters=None): """ client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client return { - v.id.node_id: v - for v in _iterate_node_executions(client, self.sdk_workflow_execution.id, filters=filters) + v.id.node_id: v for v in _iterate_node_executions(client, self.sdk_workflow_execution.id, filters=filters) } def sync(self): @@ -467,13 +463,11 @@ def terminate(self, cause): :param Text cause: """ _FlyteClientManager( - _platform_config.URL.get(), - insecure=_platform_config.INSECURE.get() + _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() ).client.terminate_execution(self.sdk_workflow_execution.id, cause) class FlyteNodeExecution(_common_engine.BaseNodeExecution): - def get_task_executions(self): """ :rtype: list[flytekit.common.tasks.executions.SdkTaskExecution] @@ -526,7 +520,6 @@ def sync(self): class FlyteTaskExecution(_common_engine.BaseTaskExecution): - def get_inputs(self): """ :rtype: flytekit.models.literals.LiteralMap @@ -573,8 +566,6 @@ def get_child_executions(self, filters=None): return { v.id.node_id: v for v in _iterate_node_executions( - client, - task_execution_identifier=self.sdk_task_execution.id, - filters=filters + client, task_execution_identifier=self.sdk_task_execution.id, filters=filters, ) } diff --git 
a/flytekit/engines/loader.py b/flytekit/engines/loader.py index c896c908af..9471e55ce1 100644 --- a/flytekit/engines/loader.py +++ b/flytekit/engines/loader.py @@ -1,12 +1,14 @@ from __future__ import absolute_import -from flytekit.common.exceptions import user as _user_exceptions, scopes as _exception_scopes -from flytekit.configuration import sdk as _sdk_config + import importlib as _importlib +from flytekit.common.exceptions import scopes as _exception_scopes +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.configuration import sdk as _sdk_config _ENGINE_NAME_TO_MODULES_CACHE = { - 'flyte': ('flytekit.engines.flyte.engine', 'FlyteEngineFactory', None), - 'unit': ('flytekit.engines.unit.engine', 'UnitTestEngineFactory', None), + "flyte": ("flytekit.engines.flyte.engine", "FlyteEngineFactory", None), + "unit": ("flytekit.engines.unit.engine", "UnitTestEngineFactory", None), # 'local': ('flytekit.engines.local.engine', 'EngineObjectFactory', None) } @@ -23,9 +25,8 @@ def get_engine(engine_name=None): raise _user_exceptions.FlyteValueException( engine_name, "Could not load an engine with the identifier '{}'. Known engines are: {}".format( - engine_name, - list(_ENGINE_NAME_TO_MODULES_CACHE.keys()) - ) + engine_name, list(_ENGINE_NAME_TO_MODULES_CACHE.keys()) + ), ) module_path, attr, engine_impl = _ENGINE_NAME_TO_MODULES_CACHE[engine_name] @@ -36,9 +37,7 @@ def get_engine(engine_name=None): raise _user_exceptions.FlyteValueException( module, "Failed to load the engine because the attribute named '{}' could not be found" - "in the module '{}'.".format( - attr, module_path - ) + "in the module '{}'.".format(attr, module_path), ) engine_impl = getattr(module, attr)() diff --git a/flytekit/engines/unit/engine.py b/flytekit/engines/unit/engine.py index ec6ffd1af5..95a2f17b1e 100644 --- a/flytekit/engines/unit/engine.py +++ b/flytekit/engines/unit/engine.py @@ -2,26 +2,29 @@ import logging as _logging import os as _os +from datetime import datetime as _datetime import six as _six -from datetime import datetime as _datetime +from flyteidl.plugins import qubole_pb2 as _qubole_pb2 +from google.protobuf.json_format import ParseDict as _ParseDict from six import moves as _six_moves -from google.protobuf.json_format import ParseDict as _ParseDict -from flyteidl.plugins import qubole_pb2 as _qubole_pb2 -from flytekit.common import constants as _sdk_constants, utils as _common_utils -from flytekit.common.exceptions import user as _user_exceptions, system as _system_exception +from flytekit.common import constants as _sdk_constants +from flytekit.common import utils as _common_utils +from flytekit.common.exceptions import system as _system_exception +from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import helpers as _type_helpers from flytekit.configuration import TemporaryConfiguration as _TemporaryConfiguration from flytekit.engines import common as _common_engine from flytekit.engines.unit.mock_stats import MockStats from flytekit.interfaces.data import data_proxy as _data_proxy -from flytekit.models import literals as _literals, array_job as _array_job, qubole as _qubole_models +from flytekit.models import array_job as _array_job +from flytekit.models import literals as _literals +from flytekit.models import qubole as _qubole_models from flytekit.models.core.identifier import WorkflowExecutionIdentifier class UnitTestEngineFactory(_common_engine.BaseExecutionEngineFactory): - def get_task(self, sdk_task): """ :param 
flytekit.common.tasks.task.SdkTask sdk_task: @@ -43,9 +46,7 @@ def get_task(self, sdk_task): return HiveTask(sdk_task) else: raise _user_exceptions.FlyteAssertion( - "Unit tests are not currently supported for tasks of type: {}".format( - sdk_task.type - ) + "Unit tests are not currently supported for tasks of type: {}".format(sdk_task.type) ) def get_workflow(self, _): @@ -88,8 +89,7 @@ def execute(self, inputs, context=None): :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity] """ with _TemporaryConfiguration( - _os.path.join(_os.path.dirname(__file__), 'unit.config'), - internal_overrides={'image': 'unit_image'} + _os.path.join(_os.path.dirname(__file__), "unit.config"), internal_overrides={"image": "unit_image"}, ): with _common_utils.AutoDeletingTempDir("unit_test_dir") as working_directory: with _data_proxy.LocalWorkingDirectoryContext(working_directory): @@ -103,17 +103,13 @@ def _execute_user_code(self, inputs): with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory: return self.sdk_task.execute( _common_engine.EngineContext( - execution_id=WorkflowExecutionIdentifier( - project='unit_test', - domain='unit_test', - name='unit_test' - ), + execution_id=WorkflowExecutionIdentifier(project="unit_test", domain="unit_test", name="unit_test"), execution_date=_datetime.utcnow(), stats=MockStats(), logging=_logging, # TODO: A mock logging object that we can read later. - tmp_dir=user_working_directory + tmp_dir=user_working_directory, ), - inputs + inputs, ) def _transform_for_user_output(self, outputs): @@ -128,8 +124,17 @@ def _transform_for_user_output(self, outputs): def register(self, identifier, version): raise _user_exceptions.FlyteAssertion("You cannot register unit test tasks.") - def launch(self, project, domain, name=None, inputs=None, notification_overrides=None, label_overrides=None, - annotation_overrides=None, auth_role=None): + def launch( + self, + project, + domain, + name=None, + inputs=None, + notification_overrides=None, + label_overrides=None, + annotation_overrides=None, + auth_role=None, + ): raise _user_exceptions.FlyteAssertion("You cannot launch unit test tasks.") @@ -143,15 +148,13 @@ def _transform_for_user_output(self, outputs): literal_map = outputs[_sdk_constants.OUTPUT_FILE_NAME] return { name: _type_helpers.get_sdk_value_from_literal( - literal_map.literals[name], - sdk_type=_type_helpers.get_sdk_type_from_literal_type(variable.type) + literal_map.literals[name], sdk_type=_type_helpers.get_sdk_type_from_literal_type(variable.type), ).to_python_std() for name, variable in _six.iteritems(self.sdk_task.interface.outputs) } class DynamicTask(ReturnOutputsTask): - def __init__(self, *args, **kwargs): self._has_workflow_node = False super(DynamicTask, self).__init__(*args, **kwargs) @@ -179,10 +182,12 @@ def _execute_user_code(self, inputs): for future_node in futures.nodes: if future_node.workflow_node is not None: # TODO: implement proper unit testing for launchplan and subworkflow nodes somehow - _logging.warning("A workflow node has been detected in the output of the dynamic task. The " - "Flytekit unit test engine is incomplete for dynamic tasks that return launch " - "plans or subworkflows. The generated dynamic job spec will be returned but " - "they will not be run.") + _logging.warning( + "A workflow node has been detected in the output of the dynamic task. The " + "Flytekit unit test engine is incomplete for dynamic tasks that return launch " + "plans or subworkflows. 
The generated dynamic job spec will be returned but " + "they will not be run." + ) # For now, just return the output of the parent task self._has_workflow_node = True return results @@ -202,7 +207,9 @@ def _execute_user_code(self, inputs): if inputs_path not in results: raise _system_exception.FlyteSystemAssertion( "dynamic task hasn't generated expected inputs document [{}] found {}".format( - future_node.id, list(results.keys()))) + future_node.id, list(results.keys()) + ) + ) sub_task_output = UnitTestEngineFactory().get_task(task).execute(results[inputs_path]) sub_task_outputs[future_node.id] = sub_task_output @@ -226,19 +233,20 @@ def execute_array_task(root_input_path, task, array_inputs): array_job = _array_job.ArrayJob.from_dict(task.custom) outputs = {} for job_index in _six_moves.range(0, array_job.size): - inputs_path = _os.path.join(root_input_path, _six.text_type(job_index), _sdk_constants.INPUT_FILE_NAME) + inputs_path = _os.path.join(root_input_path, _six.text_type(job_index), _sdk_constants.INPUT_FILE_NAME,) if inputs_path not in array_inputs: raise _system_exception.FlyteSystemAssertion( - "dynamic task hasn't generated expected inputs document [{}].".format(inputs_path)) + "dynamic task hasn't generated expected inputs document [{}].".format(inputs_path) + ) input_proto = array_inputs[inputs_path] # All outputs generated by the same array job will have the same key in sub_task_outputs, # they will, however, differ in the var names; they will be on the format [<index>].<var_name>. # e.g. [1].out1 for key, val in _six.iteritems( - ReturnOutputsTask( - task.assign_type_and_return(_sdk_constants.SdkTaskType.PYTHON_TASK) # TODO: This is weird - ).execute(input_proto) + ReturnOutputsTask( + task.assign_type_and_return(_sdk_constants.SdkTaskType.PYTHON_TASK) # TODO: This is weird + ).execute(input_proto) ): outputs["[{}].{}".format(job_index, key)] = val return outputs @@ -256,27 +264,37 @@ def fulfil_bindings(binding_data, fulfilled_promises): if binding_data.scalar: return _literals.Literal(scalar=binding_data.scalar) elif binding_data.collection: - return _literals.Literal(collection=_literals.LiteralCollection( - [DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) for sub_binding_data in - binding_data.collection.bindings])) + return _literals.Literal( + collection=_literals.LiteralCollection( + [ + DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) + for sub_binding_data in binding_data.collection.bindings + ] + ) + ) elif binding_data.promise: if binding_data.promise.node_id not in fulfilled_promises: raise _system_exception.FlyteSystemAssertion( - "Expecting output of node [{}] but that hasn't been produced.".format(binding_data.promise.node_id)) + "Expecting output of node [{}] but that hasn't been produced.".format(binding_data.promise.node_id) + ) node_output = fulfilled_promises[binding_data.promise.node_id] if binding_data.promise.var not in node_output: raise _system_exception.FlyteSystemAssertion( "Expecting output [{}] of node [{}] but that hasn't been produced.".format( - binding_data.promise.var, - binding_data.promise.node_id)) + binding_data.promise.var, binding_data.promise.node_id + ) + ) return binding_data.promise.sdk_type.from_python_std(node_output[binding_data.promise.var]) elif binding_data.map: - return _literals.Literal(map=_literals.LiteralMap( - { - k: DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) for k, sub_binding_data in - _six.iteritems(binding_data.map.bindings) - })) + return _literals.Literal( +
map=_literals.LiteralMap( + { + k: DynamicTask.fulfil_bindings(sub_binding_data, fulfilled_promises) + for k, sub_binding_data in _six.iteritems(binding_data.map.bindings) + } + ) + ) class HiveTask(DynamicTask): @@ -296,9 +314,7 @@ def _transform_for_user_output(self, outputs): for t in futures.tasks } for node in futures.nodes: - queries.append( - task_ids_to_defs[node.task_node.reference_id.name].query.query - ) + queries.append(task_ids_to_defs[node.task_node.reference_id.name].query.query) return queries else: return [] diff --git a/flytekit/engines/unit/mock_stats.py b/flytekit/engines/unit/mock_stats.py index dab1f52a8b..2239156080 100644 --- a/flytekit/engines/unit/mock_stats.py +++ b/flytekit/engines/unit/mock_stats.py @@ -1,7 +1,7 @@ from __future__ import absolute_import -import logging import datetime as _datetime +import logging class MockStats(object): diff --git a/flytekit/interfaces/data/common.py b/flytekit/interfaces/data/common.py index 6968463429..0491c1476f 100644 --- a/flytekit/interfaces/data/common.py +++ b/flytekit/interfaces/data/common.py @@ -1,5 +1,7 @@ from __future__ import absolute_import + import abc as _abc + import six as _six diff --git a/flytekit/interfaces/data/data_proxy.py b/flytekit/interfaces/data/data_proxy.py index f2d9f137d2..e51659cdd0 100644 --- a/flytekit/interfaces/data/data_proxy.py +++ b/flytekit/interfaces/data/data_proxy.py @@ -1,13 +1,16 @@ from __future__ import absolute_import -from flytekit.configuration import sdk as _sdk_config, platform as _platform_config -from flytekit.interfaces.data.s3 import s3proxy as _s3proxy +import six as _six + +from flytekit.common import constants as _constants +from flytekit.common import utils as _common_utils +from flytekit.common.exceptions import user as _user_exception +from flytekit.configuration import platform as _platform_config +from flytekit.configuration import sdk as _sdk_config from flytekit.interfaces.data.gcs import gcs_proxy as _gcs_proxy -from flytekit.interfaces.data.local import local_file_proxy as _local_file_proxy from flytekit.interfaces.data.http import http_data_proxy as _http_data_proxy -from flytekit.common.exceptions import user as _user_exception -from flytekit.common import utils as _common_utils, constants as _constants -import six as _six +from flytekit.interfaces.data.local import local_file_proxy as _local_file_proxy +from flytekit.interfaces.data.s3 import s3proxy as _s3proxy class LocalWorkingDirectoryContext(object): @@ -74,8 +77,7 @@ def __init__(self, cloud_provider=None): if proxy is None: raise _user_exception.FlyteAssertion( "Configured cloud provider is not supported for data I/O. 
Received: {}, expected one of: {}".format( - cloud_provider, - list(type(self)._CLOUD_PROVIDER_TO_PROXIES.keys()) + cloud_provider, list(type(self)._CLOUD_PROVIDER_TO_PROXIES.keys()) ) ) super(RemoteDataContext, self).__init__(proxy) @@ -132,7 +134,7 @@ def get_data(cls, remote_path, local_path, is_multipart=False): remote_path=remote_path, local_path=local_path, is_multipart=is_multipart, - error_string=_six.text_type(ex) + error_string=_six.text_type(ex), ) ) @@ -157,7 +159,7 @@ def put_data(cls, local_path, remote_path, is_multipart=False): remote_path=remote_path, local_path=local_path, is_multipart=is_multipart, - error_string=_six.text_type(ex) + error_string=_six.text_type(ex), ) ) diff --git a/flytekit/interfaces/data/gcs/gcs_proxy.py b/flytekit/interfaces/data/gcs/gcs_proxy.py index fd024e99b3..2313b38136 100644 --- a/flytekit/interfaces/data/gcs/gcs_proxy.py +++ b/flytekit/interfaces/data/gcs/gcs_proxy.py @@ -4,12 +4,11 @@ import sys as _sys import uuid as _uuid +from flytekit.common.exceptions.user import FlyteUserException as _FlyteUserException from flytekit.configuration import gcp as _gcp_config from flytekit.interfaces import random as _flyte_random from flytekit.interfaces.data import common as _common_data from flytekit.tools import subprocess as _subprocess -from flytekit.common.exceptions.user import FlyteUserException as _FlyteUserException - if _sys.version_info >= (3,): from shutil import which as _which @@ -116,10 +115,7 @@ def upload_directory(self, local_path, remote_path): GCSProxy._check_binary() cmd = self._maybe_with_gsutil_parallelism( - "cp", - "-r", - _amend_path(local_path), - remote_path if remote_path.endswith("/") else remote_path + "/" + "cp", "-r", _amend_path(local_path), remote_path if remote_path.endswith("/") else remote_path + "/", ) return _update_cmd_config_and_execute(cmd) diff --git a/flytekit/interfaces/data/http/http_data_proxy.py b/flytekit/interfaces/data/http/http_data_proxy.py index a0d7e4f06d..bc648c4511 100644 --- a/flytekit/interfaces/data/http/http_data_proxy.py +++ b/flytekit/interfaces/data/http/http_data_proxy.py @@ -1,8 +1,9 @@ from __future__ import absolute_import import requests as _requests -from flytekit.interfaces.data import common as _common_data + from flytekit.common.exceptions import user as _user_exceptions +from flytekit.interfaces.data import common as _common_data class HttpFileProxy(_common_data.DataProxy): @@ -17,14 +18,15 @@ def exists(self, path): :rtype bool: whether the file exists or not """ rsp = _requests.head(path) - allowed_codes = {type(self)._HTTP_OK, type(self)._HTTP_NOT_FOUND, type(self)._HTTP_FORBIDDEN} + allowed_codes = { + type(self)._HTTP_OK, + type(self)._HTTP_NOT_FOUND, + type(self)._HTTP_FORBIDDEN, + } if rsp.status_code not in allowed_codes: raise _user_exceptions.FlyteValueException( rsp.status_code, - "Data at {} could not be checked for existence. Expected one of: {}".format( - path, - allowed_codes - ) + "Data at {} could not be checked for existence. Expected one of: {}".format(path, allowed_codes), ) return rsp.status_code == type(self)._HTTP_OK @@ -45,9 +47,9 @@ def download(self, from_path, to_path): if rsp.status_code != type(self)._HTTP_OK: raise _user_exceptions.FlyteValueException( rsp.status_code, - "Request for data @ {} failed. Expected status code {}".format(from_path, type(self)._HTTP_OK) + "Request for data @ {} failed. 
Expected status code {}".format(from_path, type(self)._HTTP_OK), ) - with open(to_path, 'wb') as writer: + with open(to_path, "wb") as writer: writer.write(rsp.content) def upload(self, from_path, to_path): diff --git a/flytekit/interfaces/data/local/local_file_proxy.py b/flytekit/interfaces/data/local/local_file_proxy.py index 3c457d8676..77c70503a0 100644 --- a/flytekit/interfaces/data/local/local_file_proxy.py +++ b/flytekit/interfaces/data/local/local_file_proxy.py @@ -4,8 +4,9 @@ import uuid as _uuid from distutils import dir_util as _dir_util from shutil import copyfile as _copyfile -from flytekit.interfaces.data import common as _common_data + from flytekit.interfaces import random as _flyte_random +from flytekit.interfaces.data import common as _common_data def _make_local_path(path): @@ -18,7 +19,6 @@ def _make_local_path(path): class LocalFileProxy(_common_data.DataProxy): - def __init__(self, sandbox): """ :param Text sandbox: diff --git a/flytekit/interfaces/data/s3/s3proxy.py b/flytekit/interfaces/data/s3/s3proxy.py index 39cdc4d9d6..8ed560b35d 100644 --- a/flytekit/interfaces/data/s3/s3proxy.py +++ b/flytekit/interfaces/data/s3/s3proxy.py @@ -5,14 +5,15 @@ import string as _string import sys as _sys import uuid as _uuid -from six import moves as _six_moves, text_type as _text_type +from six import moves as _six_moves +from six import text_type as _text_type + +from flytekit.common.exceptions.user import FlyteUserException as _FlyteUserException from flytekit.configuration import aws as _aws_config from flytekit.interfaces import random as _flyte_random from flytekit.interfaces.data import common as _common_data from flytekit.tools import subprocess as _subprocess -from flytekit.common.exceptions.user import FlyteUserException as _FlyteUserException - if _sys.version_info >= (3,): from shutil import which as _which @@ -46,7 +47,7 @@ def _check_binary(): Make sure that the AWS cli is present """ if not _which(AwsS3Proxy._AWS_CLI): - raise _FlyteUserException('AWS CLI not found at Please install.') + raise _FlyteUserException("AWS CLI not found at Please install.") @staticmethod def _split_s3_path_to_bucket_and_key(path): @@ -54,9 +55,9 @@ def _split_s3_path_to_bucket_and_key(path): :param Text path: :rtype: (Text, Text) """ - path = path[len("s3://"):] - first_slash = path.index('/') - return path[:first_slash], path[first_slash + 1:] + path = path[len("s3://") :] + first_slash = path.index("/") + return path[:first_slash], path[first_slash + 1 :] def exists(self, remote_path): """ @@ -69,7 +70,15 @@ def exists(self, remote_path): raise ValueError("Not an S3 ARN. Please use FQN (S3 ARN) of the format s3://...") bucket, file_path = self._split_s3_path_to_bucket_and_key(remote_path) - cmd = [AwsS3Proxy._AWS_CLI, "s3api", "head-object", "--bucket", bucket, "--key", file_path] + cmd = [ + AwsS3Proxy._AWS_CLI, + "s3api", + "head-object", + "--bucket", + bucket, + "--key", + file_path, + ] try: _update_cmd_config_and_execute(cmd) return True @@ -78,7 +87,7 @@ def exists(self, remote_path): # the http status code: "An error occurred (404) when calling the HeadObject operation: Not Found" # This is a best effort for returning if the object does not exist by searching # for existence of (404) in the error message. 
This should not be needed when we get off the cli and use lib - if _re.search('(404)', _text_type(ex)): + if _re.search("(404)", _text_type(ex)): return False else: raise ex @@ -116,7 +125,7 @@ def upload(self, file_path, to_path): AwsS3Proxy._check_binary() extra_args = { - 'ACL': 'bucket-owner-full-control', + "ACL": "bucket-owner-full-control", } cmd = [AwsS3Proxy._AWS_CLI, "s3", "cp"] @@ -136,7 +145,7 @@ def upload_directory(self, local_path, remote_path): :param Text remote_path: """ extra_args = { - 'ACL': 'bucket-owner-full-control', + "ACL": "bucket-owner-full-control", } if not remote_path.startswith("s3://"): diff --git a/flytekit/interfaces/random.py b/flytekit/interfaces/random.py index 114a2b9ea7..8a112c95f8 100644 --- a/flytekit/interfaces/random.py +++ b/flytekit/interfaces/random.py @@ -1,4 +1,5 @@ from __future__ import absolute_import + import random as _random random = _random.Random() diff --git a/flytekit/interfaces/stats/client.py b/flytekit/interfaces/stats/client.py index 5710209c55..6c2fea3afa 100644 --- a/flytekit/interfaces/stats/client.py +++ b/flytekit/interfaces/stats/client.py @@ -10,10 +10,11 @@ from flytekit.configuration import statsd as _statsd_config RESERVED_TAG_WORDS = frozenset( - ['asg', 'az', 'backend', 'canary', 'host', 'period', 'region', 'shard', 'window', 'source']) + ["asg", "az", "backend", "canary", "host", "period", "region", "shard", "window", "source"] +) # TODO should this be a whitelist instead? -FORBIDDEN_TAG_VALUE_CHARACTERS = re.compile('[|.:]') +FORBIDDEN_TAG_VALUE_CHARACTERS = re.compile("[|.:]") _stats_client = None @@ -39,29 +40,31 @@ def __init__(self, client, prefix=None): for extendable_func in self.EXTENDABLE_FUNC: base_func = getattr(self._client, extendable_func) if base_func: - setattr(self, - extendable_func, - self._create_wrapped_function(base_func)) + setattr(self, extendable_func, self._create_wrapped_function(base_func)) def _create_wrapped_function(self, base_func): if self._scope_prefix: + def name_wrap(stat, *args, **kwargs): - tags = kwargs.pop('tags', {}) - if kwargs.pop('per_host', False): - tags['_f'] = 'i' + tags = kwargs.pop("tags", {}) + if kwargs.pop("per_host", False): + tags["_f"] = "i" if bool(tags): stat = self._serialize_tags(stat, tags) return base_func(self._p_with_prefix(stat), *args, **kwargs) + else: + def name_wrap(stat, *args, **kwargs): - tags = kwargs.pop('tags', {}) - if kwargs.pop('per_host', False): - tags['_f'] = 'i' + tags = kwargs.pop("tags", {}) + if kwargs.pop("per_host", False): + tags["_f"] = "i" if bool(tags): stat = self._serialize_tags(stat, tags) return base_func(stat, *args, **kwargs) + return name_wrap def get_stats(self, name): @@ -73,8 +76,7 @@ def get_stats(self, name): return ScopeableStatsProxy(self._client, prefix) def pipeline(self): - return ScopeableStatsProxy(self._client.pipeline(), - self._scope_prefix) + return ScopeableStatsProxy(self._client.pipeline(), self._scope_prefix) def _p_with_prefix(self, name): if name is None: @@ -85,7 +87,7 @@ def _is_ascii(self, name): if sys.version_info >= (3, 7): return name.isascii() try: - return name and name.encode('ascii') + return name and name.encode("ascii") except UnicodeEncodeError: return False @@ -101,7 +103,7 @@ def _serialize_tags(self, metric, tags=None): # if the tags fail sanity check we will not serialize the tags, and simply return the metric. 
if not self._is_ascii(key): return stat - tag = FORBIDDEN_TAG_VALUE_CHARACTERS.sub('_', str(tags[key])) + tag = FORBIDDEN_TAG_VALUE_CHARACTERS.sub("_", str(tags[key])) if tag != "": metric += ".__{0}={1}".format(key, tag) except UnicodeEncodeError: @@ -117,8 +119,7 @@ def __getattr__(self, name): return getattr(self._client, name) def __enter__(self): - return ScopeableStatsProxy(self._client.__enter__(), - self._scope_prefix) + return ScopeableStatsProxy(self._client.__enter__(), self._scope_prefix) def __exit__(self, exc_type, exc_value, traceback): self._client.__exit__(exc_type, exc_value, traceback) diff --git a/flytekit/interfaces/stats/taggable.py b/flytekit/interfaces/stats/taggable.py index f874f2472a..12a78488dc 100644 --- a/flytekit/interfaces/stats/taggable.py +++ b/flytekit/interfaces/stats/taggable.py @@ -14,25 +14,29 @@ def __init__(self, client, full_prefix, prefix=None, tags=None): def _create_wrapped_function(self, base_func): if self._scope_prefix: + def name_wrap(stat, *args, **kwargs): - tags = kwargs.pop('tags', {}) + tags = kwargs.pop("tags", {}) tags.update(self._tags) - if kwargs.pop('per_host', False): - tags['_f'] = 'i' + if kwargs.pop("per_host", False): + tags["_f"] = "i" if bool(tags): stat = self._serialize_tags(stat, tags) return base_func(self._p_with_prefix(stat), *args, **kwargs) + else: + def name_wrap(stat, *args, **kwargs): - tags = kwargs.pop('tags', {}) + tags = kwargs.pop("tags", {}) tags.update(self._tags) - if kwargs.pop('per_host', False): - tags['_f'] = 'i' + if kwargs.pop("per_host", False): + tags["_f"] = "i" if bool(tags): stat = self._serialize_tags(stat, tags) return base_func(stat, *args, **kwargs) + return name_wrap def clear_tags(self): @@ -43,17 +47,13 @@ def extend_tags(self, tags): def pipeline(self): return TaggableStats( - self._client.pipeline(), - self._full_prefix, - prefix=self._scope_prefix, - tags=dict(self._tags)) + self._client.pipeline(), self._full_prefix, prefix=self._scope_prefix, tags=dict(self._tags), + ) def __enter__(self): return TaggableStats( - self._client.__enter__(), - self._full_prefix, - prefix=self._scope_prefix, - tags=dict(self._tags)) + self._client.__enter__(), self._full_prefix, prefix=self._scope_prefix, tags=dict(self._tags), + ) def get_stats(self, name, copy_tags=True): if not self._scope_prefix or self._scope_prefix == "": @@ -62,7 +62,7 @@ def get_stats(self, name, copy_tags=True): prefix = self._scope_prefix + "." + name if self._full_prefix: - full_prefix = self._full_prefix + '.' + prefix + full_prefix = self._full_prefix + "." + prefix else: full_prefix = prefix diff --git a/flytekit/models/admin/common.py b/flytekit/models/admin/common.py index 762f0e9449..0a0cdde0c1 100644 --- a/flytekit/models/admin/common.py +++ b/flytekit/models/admin/common.py @@ -1,11 +1,11 @@ from __future__ import absolute_import -from flytekit.models import common as _common from flyteidl.admin import common_pb2 as _common_pb2 +from flytekit.models import common as _common + class Sort(_common.FlyteIdlEntity): - class Direction(object): DESCENDING = _common_pb2.Sort.DESCENDING ASCENDING = _common_pb2.Sort.ASCENDING @@ -53,16 +53,20 @@ def from_python_std(cls, text): :rtype: Sort """ text = text.strip() - if text[-1] != ')': - raise ValueError("Could not parse string. Must be in format 'asc(key)' or 'desc(key)'. '{}' did not " - "end with ')'.".format(text)) + if text[-1] != ")": + raise ValueError( + "Could not parse string. Must be in format 'asc(key)' or 'desc(key)'. 
'{}' did not " + "end with ')'.".format(text) + ) if text.startswith("asc("): direction = Sort.Direction.ASCENDING - key = text[len("asc("):-1].strip() + key = text[len("asc(") : -1].strip() elif text.startswith("desc("): direction = Sort.Direction.DESCENDING - key = text[len("desc("):-1].strip() + key = text[len("desc(") : -1].strip() else: - raise ValueError("Could not parse string. Must be in format 'asc(key)' or 'desc(key)'. '{}' did not " - "start with 'asc(' or 'desc'.".format(text)) + raise ValueError( + "Could not parse string. Must be in format 'asc(key)' or 'desc(key)'. '{}' did not " + "start with 'asc(' or 'desc'.".format(text) + ) return cls(key=key, direction=direction) diff --git a/flytekit/models/admin/task_execution.py b/flytekit/models/admin/task_execution.py index bb54273f1b..8100e9343f 100644 --- a/flytekit/models/admin/task_execution.py +++ b/flytekit/models/admin/task_execution.py @@ -1,12 +1,16 @@ from __future__ import absolute_import -from flytekit.models import common as _common -from flytekit.models.core import identifier as _identifier, execution as _execution + from flyteidl.admin import task_execution_pb2 as _task_execution_pb2 +from flytekit.models import common as _common +from flytekit.models.core import execution as _execution +from flytekit.models.core import identifier as _identifier + class TaskExecutionClosure(_common.FlyteIdlEntity): - - def __init__(self, phase, logs, started_at, duration, created_at, updated_at, output_uri=None, error=None): + def __init__( + self, phase, logs, started_at, duration, created_at, updated_at, output_uri=None, error=None, + ): """ :param int phase: Enum value from flytekit.models.core.execution.TaskExecutionPhase :param list[flytekit.models.core.execution.TaskLog] logs: List of all logs associated with the execution. 
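Reviewer note on the `Sort.from_python_std` parser reformatted just above: it accepts only strings of the form `asc(key)` or `desc(key)`. A quick sketch of the expected behavior, inferred from the parsing code in this diff and assuming the usual `key`/`direction` properties on these model classes:

```python
from flytekit.models.admin.common import Sort

sort = Sort.from_python_std("desc(created_at)")
assert sort.key == "created_at"
assert sort.direction == Sort.Direction.DESCENDING

# Anything else raises ValueError, e.g. a missing closing parenthesis:
# Sort.from_python_std("desc(created_at")  # ValueError
```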
@@ -93,7 +97,7 @@ def to_flyte_idl(self): phase=self.phase, logs=[l.to_flyte_idl() for l in self.logs], output_uri=self.output_uri, - error=self.error.to_flyte_idl() if self.error is not None else None + error=self.error.to_flyte_idl() if self.error is not None else None, ) p.started_at.FromDatetime(self.started_at) p.created_at.FromDatetime(self.created_at) @@ -115,12 +119,11 @@ def from_flyte_idl(cls, p): started_at=p.started_at.ToDatetime(), created_at=p.created_at.ToDatetime(), updated_at=p.updated_at.ToDatetime(), - duration=p.duration.ToTimedelta() + duration=p.duration.ToTimedelta(), ) class TaskExecution(_common.FlyteIdlEntity): - def __init__(self, id, input_uri, closure, is_parent): """ :param flytekit.models.core.identifier.TaskExecutionIdentifier id: @@ -169,7 +172,7 @@ def to_flyte_idl(self): id=self.id.to_flyte_idl(), input_uri=self.input_uri, closure=self.closure.to_flyte_idl(), - is_parent=self.is_parent + is_parent=self.is_parent, ) @classmethod diff --git a/flytekit/models/admin/workflow.py b/flytekit/models/admin/workflow.py index 8d647f4779..8a5f87e62a 100644 --- a/flytekit/models/admin/workflow.py +++ b/flytekit/models/admin/workflow.py @@ -1,11 +1,14 @@ from __future__ import absolute_import -from flytekit.models import common as _common -from flytekit.models.core import compiler as _compiler_models, identifier as _identifier, workflow as _core_workflow + from flyteidl.admin import workflow_pb2 as _admin_workflow +from flytekit.models import common as _common +from flytekit.models.core import compiler as _compiler_models +from flytekit.models.core import identifier as _identifier +from flytekit.models.core import workflow as _core_workflow -class WorkflowSpec(_common.FlyteIdlEntity): +class WorkflowSpec(_common.FlyteIdlEntity): def __init__(self, template, sub_workflows): """ This object fully encapsulates the specification of a workflow @@ -34,8 +37,7 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.workflow_pb2.WorkflowSpec """ return _admin_workflow.WorkflowSpec( - template=self._template.to_flyte_idl(), - sub_workflows=[s.to_flyte_idl() for s in self._sub_workflows], + template=self._template.to_flyte_idl(), sub_workflows=[s.to_flyte_idl() for s in self._sub_workflows], ) @classmethod @@ -44,17 +46,14 @@ def from_flyte_idl(cls, pb2_object): :param pb2_object: flyteidl.admin.workflow_pb2.WorkflowSpec :rtype: WorkflowSpec """ - return cls(_core_workflow.WorkflowTemplate.from_flyte_idl(pb2_object.template), - [_core_workflow.WorkflowTemplate.from_flyte_idl(s) for s in pb2_object.sub_workflows]) + return cls( + _core_workflow.WorkflowTemplate.from_flyte_idl(pb2_object.template), + [_core_workflow.WorkflowTemplate.from_flyte_idl(s) for s in pb2_object.sub_workflows], + ) class Workflow(_common.FlyteIdlEntity): - - def __init__( - self, - id, - closure - ): + def __init__(self, id, closure): """ :param flytekit.models.core.identifier.Identifier id: :param WorkflowClosure closure: @@ -80,10 +79,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.workflow_pb2.Workflow """ - return _admin_workflow.Workflow( - id=self.id.to_flyte_idl(), - closure=self.closure.to_flyte_idl() - ) + return _admin_workflow.Workflow(id=self.id.to_flyte_idl(), closure=self.closure.to_flyte_idl()) @classmethod def from_flyte_idl(cls, pb2_object): @@ -93,12 +89,11 @@ def from_flyte_idl(cls, pb2_object): """ return cls( id=_identifier.Identifier.from_flyte_idl(pb2_object.id), - closure=WorkflowClosure.from_flyte_idl(pb2_object.closure) + 
closure=WorkflowClosure.from_flyte_idl(pb2_object.closure), ) class WorkflowClosure(_common.FlyteIdlEntity): - def __init__(self, compiled_workflow): """ :param flytekit.models.core.compiler.CompiledWorkflowClosure compiled_workflow: @@ -116,9 +111,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.workflow_pb2.WorkflowClosure """ - return _admin_workflow.WorkflowClosure( - compiled_workflow=self.compiled_workflow.to_flyte_idl() - ) + return _admin_workflow.WorkflowClosure(compiled_workflow=self.compiled_workflow.to_flyte_idl()) @classmethod def from_flyte_idl(cls, p): @@ -126,6 +119,4 @@ def from_flyte_idl(cls, p): :param flyteidl.admin.workflow_pb2.WorkflowClosure p: :rtype: WorkflowClosure """ - return cls( - compiled_workflow=_compiler_models.CompiledWorkflowClosure.from_flyte_idl(p.compiled_workflow) - ) + return cls(compiled_workflow=_compiler_models.CompiledWorkflowClosure.from_flyte_idl(p.compiled_workflow)) diff --git a/flytekit/models/array_job.py b/flytekit/models/array_job.py index 0a73de7b9a..a420f289f2 100644 --- a/flytekit/models/array_job.py +++ b/flytekit/models/array_job.py @@ -9,7 +9,6 @@ class ArrayJob(_common.FlyteCustomIdlEntity): - def __init__(self, parallelism, size, min_successes): """ Initializes a new ArrayJob. @@ -64,11 +63,9 @@ def to_dict(self): """ :rtype: dict[T, Text] """ - return _json_format.MessageToDict(_array_job.ArrayJob( - parallelism=self.parallelism, - size=self.size, - min_successes=self.min_successes - )) + return _json_format.MessageToDict( + _array_job.ArrayJob(parallelism=self.parallelism, size=self.size, min_successes=self.min_successes,) + ) @classmethod def from_dict(cls, idl_dict): @@ -78,8 +75,4 @@ def from_dict(cls, idl_dict): """ pb2_object = _json_format.Parse(_json.dumps(idl_dict), _array_job.ArrayJob()) - return cls( - parallelism=pb2_object.parallelism, - size=pb2_object.size, - min_successes=pb2_object.min_successes - ) + return cls(parallelism=pb2_object.parallelism, size=pb2_object.size, min_successes=pb2_object.min_successes,) diff --git a/flytekit/models/common.py b/flytekit/models/common.py index 8d67ce4b56..e1a4808e58 100644 --- a/flytekit/models/common.py +++ b/flytekit/models/common.py @@ -5,7 +5,8 @@ import six as _six from flyteidl.admin import common_pb2 as _common_pb2 -from google.protobuf import json_format as _json_format, struct_pb2 as _struct +from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct class FlyteABCMeta(_abc.ABCMeta): @@ -16,7 +17,6 @@ def __instancecheck__(cls, instance): class FlyteType(FlyteABCMeta): - def __repr__(cls): return cls.short_class_string() @@ -41,10 +41,8 @@ def from_flyte_idl(cls, idl_object): class FlyteIdlEntity(_six.with_metaclass(FlyteType, object)): - def __eq__(self, other): - return isinstance(other, FlyteIdlEntity) and \ - other.to_flyte_idl() == self.to_flyte_idl() + return isinstance(other, FlyteIdlEntity) and other.to_flyte_idl() == self.to_flyte_idl() def __ne__(self, other): return not (self == other) @@ -80,7 +78,6 @@ def to_flyte_idl(self): class FlyteCustomIdlEntity(FlyteIdlEntity): - @classmethod def from_flyte_idl(cls, idl_object): """ @@ -107,7 +104,6 @@ def to_dict(self): class NamedEntityIdentifier(FlyteIdlEntity): - def __init__(self, project, domain, name=None): """ :param Text project: The name of the project in which this entity lives. 
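Reviewer note: `ArrayJob` above is a `FlyteCustomIdlEntity`, so it serializes through `google.protobuf.json_format` as a JSON-style dict rather than a raw protobuf. A small round-trip sketch with illustrative values:

```python
from flytekit.models.array_job import ArrayJob

job = ArrayJob(parallelism=4, size=10, min_successes=8)

as_dict = job.to_dict()            # rendered via json_format.MessageToDict
restored = ArrayJob.from_dict(as_dict)

assert restored.parallelism == 4
assert restored.size == 10
assert restored.min_successes == 8
```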
@@ -149,11 +145,7 @@ def to_flyte_idl(self): """ # We use the kwarg constructor of the protobuf and setting name=None is equivalent to not setting it at all - return _common_pb2.NamedEntityIdentifier( - project=self.project, - domain=self.domain, - name=self.name - ) + return _common_pb2.NamedEntityIdentifier(project=self.project, domain=self.domain, name=self.name) @classmethod def from_flyte_idl(cls, idl_object): @@ -165,7 +157,6 @@ def from_flyte_idl(cls, idl_object): class EmailNotification(FlyteIdlEntity): - def __init__(self, recipients_email): """ :param list[Text] recipients_email: @@ -195,7 +186,6 @@ def from_flyte_idl(cls, pb2_object): class SlackNotification(FlyteIdlEntity): - def __init__(self, recipients_email): """ :param list[Text] recipients_email: @@ -225,7 +215,6 @@ def from_flyte_idl(cls, pb2_object): class PagerDutyNotification(FlyteIdlEntity): - def __init__(self, recipients_email): """ :param list[Text] recipients_email: @@ -255,7 +244,6 @@ def from_flyte_idl(cls, pb2_object): class Notification(FlyteIdlEntity): - def __init__(self, phases, email=None, pager_duty=None, slack=None): """ Represents a structure for notifications based on execution status. @@ -307,7 +295,7 @@ def to_flyte_idl(self): phases=self.phases, email=self.email.to_flyte_idl() if self.email else None, pager_duty=self.pager_duty.to_flyte_idl() if self.pager_duty else None, - slack=self.slack.to_flyte_idl() if self.slack else None + slack=self.slack.to_flyte_idl() if self.slack else None, ) @classmethod @@ -341,9 +329,7 @@ def to_flyte_idl(self): """ :rtype: dict[Text, Text] """ - return _common_pb2.Labels( - values={k: v for k, v in _six.iteritems(self.values)} - ) + return _common_pb2.Labels(values={k: v for k, v in _six.iteritems(self.values)}) @classmethod def from_flyte_idl(cls, pb2_object): @@ -371,9 +357,7 @@ def to_flyte_idl(self): """ :rtype: _common_pb2.Annotations """ - return _common_pb2.Annotations( - values={k: v for k, v in _six.iteritems(self.values)} - ) + return _common_pb2.Annotations(values={k: v for k, v in _six.iteritems(self.values)}) @classmethod def from_flyte_idl(cls, pb2_object): @@ -385,7 +369,6 @@ def from_flyte_idl(cls, pb2_object): class UrlBlob(FlyteIdlEntity): - def __init__(self, url, bytes): """ :param Text url: @@ -469,13 +452,13 @@ def from_flyte_idl(cls, pb2_object): """ return cls( assumable_iam_role=pb2_object.assumable_iam_role if pb2_object.HasField("assumable_iam_role") else None, - kubernetes_service_account=pb2_object.kubernetes_service_account if - pb2_object.HasField("kubernetes_service_account") else None, + kubernetes_service_account=pb2_object.kubernetes_service_account + if pb2_object.HasField("kubernetes_service_account") + else None, ) class RawOutputDataConfig(FlyteIdlEntity): - def __init__(self, output_location_prefix): """ :param Text output_location_prefix: Location of offloaded data for things like S3, etc. 
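The notification models above compose in the obvious way. A hedged sketch, assuming the phase constants exposed by `flytekit.models.core.execution.WorkflowExecutionPhase` (the recipient address is invented):

```python
from flytekit.models.common import EmailNotification, Notification
from flytekit.models.core.execution import WorkflowExecutionPhase

notification = Notification(
    phases=[WorkflowExecutionPhase.SUCCEEDED, WorkflowExecutionPhase.FAILED],
    email=EmailNotification(recipients_email=["oncall@example.com"]),  # placeholder address
)
# email, pager_duty and slack are mutually optional; unset targets serialize as None.
pb = notification.to_flyte_idl()
```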
@@ -490,12 +473,8 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.common_pb2.RawOutputDataConfig """ - return _common_pb2.RawOutputDataConfig( - output_location_prefix=self.output_location_prefix - ) + return _common_pb2.RawOutputDataConfig(output_location_prefix=self.output_location_prefix) @classmethod def from_flyte_idl(cls, pb2): - return cls( - output_location_prefix=pb2.output_location_prefix - ) + return cls(output_location_prefix=pb2.output_location_prefix) diff --git a/flytekit/models/core/compiler.py b/flytekit/models/core/compiler.py index 33acb37c42..f1e56ff7d8 100644 --- a/flytekit/models/core/compiler.py +++ b/flytekit/models/core/compiler.py @@ -1,14 +1,13 @@ from __future__ import absolute_import +import six as _six from flyteidl.core import compiler_pb2 as _compiler_pb2 + from flytekit.models import common as _common from flytekit.models.core import workflow as _core_workflow_models -import six as _six - class ConnectionSet(_common.FlyteIdlEntity): - class IdList(_common.FlyteIdlEntity): def __init__(self, ids): """ @@ -65,7 +64,7 @@ def to_flyte_idl(self): """ return _compiler_pb2.ConnectionSet( upstream={k: v.to_flyte_idl() for k, v in _six.iteritems(self.upstream)}, - downstream={k: v.to_flyte_idl() for k, v in _six.iteritems(self.upstream)} + downstream={k: v.to_flyte_idl() for k, v in _six.iteritems(self.downstream)}, ) @classmethod @@ -76,12 +75,11 @@ def from_flyte_idl(cls, p): """ return cls( upstream={k: ConnectionSet.IdList.from_flyte_idl(v) for k, v in _six.iteritems(p.upstream)}, - downstream={k: ConnectionSet.IdList.from_flyte_idl(v) for k, v in _six.iteritems(p.downstream)} + downstream={k: ConnectionSet.IdList.from_flyte_idl(v) for k, v in _six.iteritems(p.downstream)}, ) class CompiledWorkflow(_common.FlyteIdlEntity): - def __init__(self, template, connections): """ :param flytekit.models.core.workflow.WorkflowTemplate template: @@ -109,8 +107,7 @@ def to_flyte_idl(self): :rtype: flyteidl.core.compiler_pb2.CompiledWorkflow """ return _compiler_pb2.CompiledWorkflow( - template=self.template.to_flyte_idl(), - connections=self.connections.to_flyte_idl() + template=self.template.to_flyte_idl(), connections=self.connections.to_flyte_idl(), ) @classmethod @@ -121,13 +118,12 @@ def from_flyte_idl(cls, p): """ return cls( template=_core_workflow_models.WorkflowTemplate.from_flyte_idl(p.template), - connections=ConnectionSet.from_flyte_idl(p.connections) + connections=ConnectionSet.from_flyte_idl(p.connections), ) # TODO: properly sort out the model code and remove one of these duplicate CompiledTasks class CompiledTask(_common.FlyteIdlEntity): - def __init__(self, template): """ :param TODO template: @@ -145,9 +141,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.compiler_pb2.CompiledTask """ - return _compiler_pb2.CompiledTask( - template=self.template # TODO: .to_flyte_idl() - ) + return _compiler_pb2.CompiledTask(template=self.template) # TODO: .to_flyte_idl() @classmethod def from_flyte_idl(cls, p): @@ -160,7 +154,6 @@ def from_flyte_idl(cls, p): class CompiledWorkflowClosure(_common.FlyteIdlEntity): - def __init__(self, primary, sub_workflows, tasks): """ :param CompiledWorkflow primary: @@ -199,7 +192,7 @@ def to_flyte_idl(self): return _compiler_pb2.CompiledWorkflowClosure( primary=self.primary.to_flyte_idl(), sub_workflows=[s.to_flyte_idl() for s in self.sub_workflows], - tasks=[t.to_flyte_idl() for t in self.tasks] + tasks=[t.to_flyte_idl() for t in self.tasks], ) @classmethod @@ -211,8 +204,9 @@ def from_flyte_idl(cls, p): # This import is here to prevent a
circular dependency issue. # TODO: properly sort out the model code and remove the duplicate CompiledTask from flytekit.models.task import CompiledTask as _CompiledTask + return cls( primary=CompiledWorkflow.from_flyte_idl(p.primary), sub_workflows=[CompiledWorkflow.from_flyte_idl(s) for s in p.sub_workflows], - tasks=[_CompiledTask.from_flyte_idl(t) for t in p.tasks] + tasks=[_CompiledTask.from_flyte_idl(t) for t in p.tasks], ) diff --git a/flytekit/models/core/condition.py b/flytekit/models/core/condition.py index 494176a999..d51e258113 100644 --- a/flytekit/models/core/condition.py +++ b/flytekit/models/core/condition.py @@ -2,7 +2,8 @@ from flyteidl.core import condition_pb2 as _condition -from flytekit.models import common as _common, literals as _literals +from flytekit.models import common as _common +from flytekit.models import literals as _literals class ComparisonExpression(_common.FlyteIdlEntity): @@ -59,15 +60,19 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.condition_pb2.ComparisonExpression """ - return _condition.ComparisonExpression(operator=self.operator, - left_value=self.left_value.to_flyte_idl(), - right_value=self.right_value.to_flyte_idl()) + return _condition.ComparisonExpression( + operator=self.operator, + left_value=self.left_value.to_flyte_idl(), + right_value=self.right_value.to_flyte_idl(), + ) @classmethod def from_flyte_idl(cls, pb2_object): - return cls(operator=pb2_object.operator, - left_value=Operand.from_flyte_idl(pb2_object.left_value), - right_value=Operand.from_flyte_idl(pb2_object.right_value)) + return cls( + operator=pb2_object.operator, + left_value=Operand.from_flyte_idl(pb2_object.left_value), + right_value=Operand.from_flyte_idl(pb2_object.right_value), + ) class ConjunctionExpression(_common.FlyteIdlEntity): @@ -115,15 +120,19 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.condition_pb2.ConjunctionExpression """ - return _condition.ConjunctionExpression(operator=self.operator, - left_expression=self.left_expression.to_flyte_idl(), - right_expression=self.right_expression.to_flyte_idl()) + return _condition.ConjunctionExpression( + operator=self.operator, + left_expression=self.left_expression.to_flyte_idl(), + right_expression=self.right_expression.to_flyte_idl(), + ) @classmethod def from_flyte_idl(cls, pb2_object): - return cls(operator=pb2_object.operator, - left_expression=BooleanExpression.from_flyte_idl(pb2_object.left_expression), - right_expression=BooleanExpression.from_flyte_idl(pb2_object.right_expression)) + return cls( + operator=pb2_object.operator, + left_expression=BooleanExpression.from_flyte_idl(pb2_object.left_expression), + right_expression=BooleanExpression.from_flyte_idl(pb2_object.right_expression), + ) class Operand(_common.FlyteIdlEntity): @@ -157,14 +166,18 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.condition_pb2.Operand """ - return _condition.Operand(primitive=self.primitive.to_flyte_idl() if self.primitive else None, - var=self.var if self.var else None) + return _condition.Operand( + primitive=self.primitive.to_flyte_idl() if self.primitive else None, var=self.var if self.var else None, + ) @classmethod def from_flyte_idl(cls, pb2_object): - return cls(primitive=_literals.Primitive.from_flyte_idl(pb2_object.primitive) if pb2_object.HasField( - 'primitive') else None, - var=pb2_object.var if pb2_object.HasField('var') else None) + return cls( + primitive=_literals.Primitive.from_flyte_idl(pb2_object.primitive) + if pb2_object.HasField("primitive") + else None, + var=pb2_object.var if 
pb2_object.HasField("var") else None, + ) class BooleanExpression(_common.FlyteIdlEntity): @@ -200,12 +213,18 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.condition_pb2.BooleanExpression """ - return _condition.BooleanExpression(conjunction=self.conjunction.to_flyte_idl() if self.conjunction else None, - comparison=self.comparison.to_flyte_idl() if self.comparison else None) + return _condition.BooleanExpression( + conjunction=self.conjunction.to_flyte_idl() if self.conjunction else None, + comparison=self.comparison.to_flyte_idl() if self.comparison else None, + ) @classmethod def from_flyte_idl(cls, pb2_object): - return cls(conjunction=ConjunctionExpression.from_flyte_idl( - pb2_object.conjunction) if pb2_object.HasField('conjunction') else None, - comparison=ComparisonExpression.from_flyte_idl( - pb2_object.comparison) if pb2_object.HasField('comparison') else None) + return cls( + conjunction=ConjunctionExpression.from_flyte_idl(pb2_object.conjunction) + if pb2_object.HasField("conjunction") + else None, + comparison=ComparisonExpression.from_flyte_idl(pb2_object.comparison) + if pb2_object.HasField("comparison") + else None, + ) diff --git a/flytekit/models/core/errors.py b/flytekit/models/core/errors.py index f13cb5703a..d109d8382e 100644 --- a/flytekit/models/core/errors.py +++ b/flytekit/models/core/errors.py @@ -1,10 +1,11 @@ from __future__ import absolute_import -from flytekit.models import common as _common + from flyteidl.core import errors_pb2 as _errors_pb2 +from flytekit.models import common as _common + class ContainerError(_common.FlyteIdlEntity): - class Kind(object): NON_RECOVERABLE = _errors_pb2.ContainerError.NON_RECOVERABLE RECOVERABLE = _errors_pb2.ContainerError.RECOVERABLE @@ -56,7 +57,6 @@ def from_flyte_idl(cls, proto): class ErrorDocument(_common.FlyteIdlEntity): - def __init__(self, error): """ :param ContainerError error: diff --git a/flytekit/models/core/execution.py b/flytekit/models/core/execution.py index 8da1ea0e1e..d16a649b5d 100644 --- a/flytekit/models/core/execution.py +++ b/flytekit/models/core/execution.py @@ -1,5 +1,7 @@ from __future__ import absolute_import + from flyteidl.core import execution_pb2 as _execution_pb2 + from flytekit.models import common as _common @@ -112,7 +114,6 @@ def enum_to_string(cls, int_value): class ExecutionError(_common.FlyteIdlEntity): - def __init__(self, code, message, error_uri): """ :param Text code: @@ -148,11 +149,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.execution_pb2.ExecutionError """ - return _execution_pb2.ExecutionError( - code=self.code, - message=self.message, - error_uri=self.error_uri, - ) + return _execution_pb2.ExecutionError(code=self.code, message=self.message, error_uri=self.error_uri,) @classmethod def from_flyte_idl(cls, p): @@ -160,15 +157,10 @@ def from_flyte_idl(cls, p): :param flyteidl.core.execution_pb2.ExecutionError p: :rtype: ExecutionError """ - return cls( - code=p.code, - message=p.message, - error_uri=p.error_uri, - ) + return cls(code=p.code, message=p.message, error_uri=p.error_uri,) class TaskLog(_common.FlyteIdlEntity): - class MessageFormat(object): UNKNOWN = _execution_pb2.TaskLog.UNKNOWN CSV = _execution_pb2.TaskLog.CSV @@ -219,11 +211,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.execution_pb2.TaskLog """ - p = _execution_pb2.TaskLog( - uri=self.uri, - name=self.name, - message_format=self.message_format - ) + p = _execution_pb2.TaskLog(uri=self.uri, name=self.name, message_format=self.message_format) p.ttl.FromTimedelta(self.ttl) return p @@ 
-233,9 +221,4 @@ def from_flyte_idl(cls, p): :param flyteidl.core.execution_pb2.TaskLog p: :rtype: TaskLog """ - return cls( - uri=p.uri, - name=p.name, - message_format=p.message_format, - ttl=p.ttl.ToTimedelta() - ) + return cls(uri=p.uri, name=p.name, message_format=p.message_format, ttl=p.ttl.ToTimedelta(),) diff --git a/flytekit/models/core/identifier.py b/flytekit/models/core/identifier.py index a36a13dfb2..1b0b0d4e17 100644 --- a/flytekit/models/core/identifier.py +++ b/flytekit/models/core/identifier.py @@ -1,7 +1,9 @@ from __future__ import absolute_import -from flytekit.models import common as _common_models + from flyteidl.core import identifier_pb2 as _identifier_pb2 +from flytekit.models import common as _common_models + class ResourceType(object): UNSPECIFIED = _identifier_pb2.UNSPECIFIED @@ -11,7 +13,6 @@ class ResourceType(object): class Identifier(_common_models.FlyteIdlEntity): - def __init__(self, resource_type, project, domain, name, version): """ :param int resource_type: enum value from ResourceType @@ -71,7 +72,7 @@ def to_flyte_idl(self): project=self.project, domain=self.domain, name=self.name, - version=self.version + version=self.version, ) @classmethod @@ -80,13 +81,7 @@ def from_flyte_idl(cls, p): :param flyteidl.core.identifier_pb2.Identifier p: :rtype: Identifier """ - return cls( - resource_type=p.resource_type, - project=p.project, - domain=p.domain, - name=p.name, - version=p.version - ) + return cls(resource_type=p.resource_type, project=p.project, domain=p.domain, name=p.name, version=p.version,) class WorkflowExecutionIdentifier(_common_models.FlyteIdlEntity): @@ -125,11 +120,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.identifier_pb2.WorkflowExecutionIdentifier """ - return _identifier_pb2.WorkflowExecutionIdentifier( - project=self.project, - domain=self.domain, - name=self.name, - ) + return _identifier_pb2.WorkflowExecutionIdentifier(project=self.project, domain=self.domain, name=self.name,) @classmethod def from_flyte_idl(cls, p): @@ -137,15 +128,10 @@ def from_flyte_idl(cls, p): :param flyteidl.core.identifier_pb2.WorkflowExecutionIdentifier p: :rtype: WorkflowExecutionIdentifier """ - return cls( - project=p.project, - domain=p.domain, - name=p.name, - ) + return cls(project=p.project, domain=p.domain, name=p.name,) class NodeExecutionIdentifier(_common_models.FlyteIdlEntity): - def __init__(self, node_id, execution_id): """ :param Text node_id: @@ -173,8 +159,7 @@ def to_flyte_idl(self): :rtype: flyteidl.core.identifier_pb2.NodeExecutionIdentifier """ return _identifier_pb2.NodeExecutionIdentifier( - node_id=self.node_id, - execution_id=self.execution_id.to_flyte_idl(), + node_id=self.node_id, execution_id=self.execution_id.to_flyte_idl(), ) @classmethod @@ -183,14 +168,10 @@ def from_flyte_idl(cls, p): :param flyteidl.core.identifier_pb2.NodeExecutionIdentifier p: :rtype: NodeExecutionIdentifier """ - return cls( - node_id=p.node_id, - execution_id=WorkflowExecutionIdentifier.from_flyte_idl(p.execution_id), - ) + return cls(node_id=p.node_id, execution_id=WorkflowExecutionIdentifier.from_flyte_idl(p.execution_id),) class TaskExecutionIdentifier(_common_models.FlyteIdlEntity): - def __init__(self, task_id, node_execution_id, retry_attempt): """ :param Identifier task_id: The identifier for the task that is executing @@ -229,7 +210,7 @@ def to_flyte_idl(self): return _identifier_pb2.TaskExecutionIdentifier( task_id=self.task_id.to_flyte_idl(), node_execution_id=self.node_execution_id.to_flyte_idl(), - retry_attempt=self.retry_attempt + 
retry_attempt=self.retry_attempt, ) @classmethod @@ -241,5 +222,5 @@ def from_flyte_idl(cls, proto): return cls( task_id=Identifier.from_flyte_idl(proto.task_id), node_execution_id=NodeExecutionIdentifier.from_flyte_idl(proto.node_execution_id), - retry_attempt=proto.retry_attempt + retry_attempt=proto.retry_attempt, ) diff --git a/flytekit/models/core/types.py b/flytekit/models/core/types.py index a578453dc5..fb0ec4090e 100644 --- a/flytekit/models/core/types.py +++ b/flytekit/models/core/types.py @@ -38,10 +38,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.types_pb2.BlobType """ - return _types_pb2.BlobType( - format=self.format, - dimensionality=self.dimensionality - ) + return _types_pb2.BlobType(format=self.format, dimensionality=self.dimensionality) @classmethod def from_flyte_idl(cls, proto): @@ -49,7 +46,4 @@ def from_flyte_idl(cls, proto): :param flyteidl.core.types_pb2.BlobType proto: :rtype: BlobType """ - return cls( - format=proto.format, - dimensionality=proto.dimensionality - ) + return cls(format=proto.format, dimensionality=proto.dimensionality) diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index fdf12c694f..2425cb1fbf 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -2,9 +2,13 @@ from flyteidl.core import workflow_pb2 as _core_workflow -from flytekit.models import common as _common, interface as _interface -from flytekit.models.core import identifier as _identifier, errors as _errors, condition as _condition -from flytekit.models.literals import RetryStrategy as _RetryStrategy, Binding as _Binding +from flytekit.models import common as _common +from flytekit.models import interface as _interface +from flytekit.models.core import condition as _condition +from flytekit.models.core import errors as _errors +from flytekit.models.core import identifier as _identifier +from flytekit.models.literals import Binding as _Binding +from flytekit.models.literals import RetryStrategy as _RetryStrategy class IfBlock(_common.FlyteIdlEntity): @@ -36,13 +40,14 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.workflow_pb2.IfBlock """ - return _core_workflow.IfBlock(condition=self.condition.to_flyte_idl(), - then_node=self.then_node.to_flyte_idl()) + return _core_workflow.IfBlock(condition=self.condition.to_flyte_idl(), then_node=self.then_node.to_flyte_idl(),) @classmethod def from_flyte_idl(cls, pb2_object): - return cls(condition=_condition.BooleanExpression.from_flyte_idl(pb2_object.condition), - then_node=Node.from_flyte_idl(pb2_object.then_node)) + return cls( + condition=_condition.BooleanExpression.from_flyte_idl(pb2_object.condition), + then_node=Node.from_flyte_idl(pb2_object.then_node), + ) class IfElseBlock(_common.FlyteIdlEntity): @@ -102,18 +107,20 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.workflow_pb2.IfElseBlock """ - return _core_workflow.IfElseBlock(case=self.case.to_flyte_idl(), - other=[a.to_flyte_idl() for a in self.other] if self.other else None, - else_node=self.else_node.to_flyte_idl() if self.else_node else None, - error=self.error.to_flyte_idl() if self.error else None) + return _core_workflow.IfElseBlock( + case=self.case.to_flyte_idl(), + other=[a.to_flyte_idl() for a in self.other] if self.other else None, + else_node=self.else_node.to_flyte_idl() if self.else_node else None, + error=self.error.to_flyte_idl() if self.error else None, + ) @classmethod def from_flyte_idl(cls, pb2_object): return cls( case=IfBlock.from_flyte_idl(pb2_object.case), 
other=[IfBlock.from_flyte_idl(a) for a in pb2_object.other], - else_node=Node.from_flyte_idl(pb2_object.else_node) if pb2_object.HasField('else_node') else None, - error=_errors.ContainerError.from_flyte_idl(pb2_object.error) if pb2_object.HasField('error') else None + else_node=Node.from_flyte_idl(pb2_object.else_node) if pb2_object.HasField("else_node") else None, + error=_errors.ContainerError.from_flyte_idl(pb2_object.error) if pb2_object.HasField("error") else None, ) @@ -147,7 +154,6 @@ def from_flyte_idl(cls, pb2_objct): class NodeMetadata(_common.FlyteIdlEntity): - def __init__(self, name, timeout, retries, interruptible=False): """ Defines extra information about the Node. @@ -193,23 +199,31 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.workflow_pb2.NodeMetadata """ - node_metadata = _core_workflow.NodeMetadata(name=self.name, retries=self.retries.to_flyte_idl(), interruptible=self.interruptible) + node_metadata = _core_workflow.NodeMetadata( + name=self.name, retries=self.retries.to_flyte_idl(), interruptible=self.interruptible, + ) node_metadata.timeout.FromTimedelta(self.timeout) return node_metadata @classmethod def from_flyte_idl(cls, pb2_object): return cls( - pb2_object.name, - pb2_object.timeout.ToTimedelta(), - _RetryStrategy.from_flyte_idl(pb2_object.retries) + pb2_object.name, pb2_object.timeout.ToTimedelta(), _RetryStrategy.from_flyte_idl(pb2_object.retries), ) class Node(_common.FlyteIdlEntity): - - def __init__(self, id, metadata, inputs, upstream_node_ids, output_aliases, task_node=None, - workflow_node=None, branch_node=None): + def __init__( + self, + id, + metadata, + inputs, + upstream_node_ids, + output_aliases, + task_node=None, + workflow_node=None, + branch_node=None, + ): """ A Workflow graph Node. One unit of execution in the graph. Each node can be linked to a Task, a Workflow or a branch node. One of the nodes must be specified. @@ -328,7 +342,7 @@ def to_flyte_idl(self): output_aliases=[a.to_flyte_idl() for a in self.output_aliases], task_node=self.task_node.to_flyte_idl() if self.task_node is not None else None, workflow_node=self.workflow_node.to_flyte_idl() if self.workflow_node is not None else None, - branch_node=self.branch_node.to_flyte_idl() if self.branch_node is not None else None + branch_node=self.branch_node.to_flyte_idl() if self.branch_node is not None else None, ) @classmethod @@ -343,16 +357,17 @@ def from_flyte_idl(cls, pb2_object): inputs=[_Binding.from_flyte_idl(b) for b in pb2_object.inputs], upstream_node_ids=pb2_object.upstream_node_ids, output_aliases=[Alias.from_flyte_idl(a) for a in pb2_object.output_aliases], - task_node=TaskNode.from_flyte_idl(pb2_object.task_node) if pb2_object.HasField('task_node') else None, + task_node=TaskNode.from_flyte_idl(pb2_object.task_node) if pb2_object.HasField("task_node") else None, workflow_node=WorkflowNode.from_flyte_idl(pb2_object.workflow_node) - if pb2_object.HasField('workflow_node') else None, - branch_node=BranchNode.from_flyte_idl(pb2_object.branch_node) if pb2_object.HasField( - 'branch_node') else None, + if pb2_object.HasField("workflow_node") + else None, + branch_node=BranchNode.from_flyte_idl(pb2_object.branch_node) + if pb2_object.HasField("branch_node") + else None, ) class TaskNode(_common.FlyteIdlEntity): - def __init__(self, reference_id): """ Refers to the task that the Node is to execute. 
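One behavioral wrinkle in `NodeMetadata` above: `from_flyte_idl` rebuilds the object from `name`, `timeout`, and `retries` only, so an `interruptible` flag survives `to_flyte_idl()` but is silently reset to its default on the way back. A short sketch (values are illustrative):

```python
import datetime

from flytekit.models.core.workflow import NodeMetadata
from flytekit.models.literals import RetryStrategy

md = NodeMetadata(
    name="n0",
    timeout=datetime.timedelta(minutes=10),
    retries=RetryStrategy(retries=3),
    interruptible=True,
)
restored = NodeMetadata.from_flyte_idl(md.to_flyte_idl())
# The flag does not survive the round trip; from_flyte_idl never passes it.
assert restored.interruptible is False
```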
@@ -387,7 +402,6 @@ def from_flyte_idl(cls, pb2_object): class WorkflowNode(_common.FlyteIdlEntity): - def __init__(self, launchplan_ref=None, sub_workflow_ref=None): """ Refers to the workflow the node is to execute. One of the references must be supplied. @@ -429,7 +443,7 @@ def to_flyte_idl(self): """ return _core_workflow.WorkflowNode( launchplan_ref=self.launchplan_ref.to_flyte_idl() if self.launchplan_ref else None, - sub_workflow_ref=self.sub_workflow_ref.to_flyte_idl() if self.sub_workflow_ref else None + sub_workflow_ref=self.sub_workflow_ref.to_flyte_idl() if self.sub_workflow_ref else None, ) @classmethod @@ -438,14 +452,13 @@ def from_flyte_idl(cls, pb2_object): """ :param flyteidl.core.workflow_pb2.WorkflowNode pb2_object: :rtype: WorkflowNode """ - if pb2_object.HasField('launchplan_ref'): + if pb2_object.HasField("launchplan_ref"): return cls(launchplan_ref=_identifier.Identifier.from_flyte_idl(pb2_object.launchplan_ref)) else: return cls(sub_workflow_ref=_identifier.Identifier.from_flyte_idl(pb2_object.sub_workflow_ref)) class WorkflowMetadata(_common.FlyteIdlEntity): - class OnFailurePolicy(object): """ Defines the execution behavior of the workflow when a failure is detected. @@ -456,19 +469,19 @@ class OnFailurePolicy(object): clean up resources before finally marking the workflow executions as failed. FAIL_AFTER_EXECUTABLE_NODES_COMPLETE Instructs the system to make as much progress as it can. The system - will not alter the dependencies of the execution graph so any node + will not alter the dependencies of the execution graph so any node that depends on the failed node will not be run. Other nodes will be executed to completion before cleaning up resources and marking the workflow execution as failed. """ - FAIL_IMMEDIATELY = _core_workflow.WorkflowMetadata.FAIL_IMMEDIATELY + FAIL_IMMEDIATELY = _core_workflow.WorkflowMetadata.FAIL_IMMEDIATELY FAIL_AFTER_EXECUTABLE_NODES_COMPLETE = _core_workflow.WorkflowMetadata.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE def __init__(self, on_failure=None): """ Metadata for the workflow. - + :param flytekit.models.core.workflow.WorkflowMetadata.OnFailurePolicy on_failure: [Optional] The execution policy when the workflow detects a failure. """ self._on_failure = on_failure @@ -496,12 +509,13 @@ def from_flyte_idl(cls, pb2_object): :rtype: WorkflowMetadata """ return cls( - on_failure=pb2_object.on_failure if pb2_object.on_failure else WorkflowMetadata.OnFailurePolicy.FAIL_IMMEDIATELY + on_failure=pb2_object.on_failure + if pb2_object.on_failure + else WorkflowMetadata.OnFailurePolicy.FAIL_IMMEDIATELY ) class WorkflowMetadataDefaults(_common.FlyteIdlEntity): - def __init__(self, interruptible=None): """ Metadata Defaults for the workflow. @@ -512,9 +526,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.workflow_pb2.WorkflowMetadataDefaults """ - return _core_workflow.WorkflowMetadataDefaults( - interruptible=self.interruptible_ - ) + return _core_workflow.WorkflowMetadataDefaults(interruptible=self.interruptible_) @classmethod def from_flyte_idl(cls, pb2_object): @@ -526,8 +538,9 @@ def from_flyte_idl(cls, pb2_object): class WorkflowTemplate(_common.FlyteIdlEntity): - - def __init__(self, id, metadata, metadata_defaults, interface, nodes, outputs, failure_node=None): + def __init__( + self, id, metadata, metadata_defaults, interface, nodes, outputs, failure_node=None, + ): """ A workflow template encapsulates all the task, branch, and subworkflow nodes to run a statically analyzable, directed acyclic graph.
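Choosing the failure policy documented in `OnFailurePolicy` is a one-liner; a hedged sketch, assuming the usual `on_failure` property accessor:

```python
from flytekit.models.core.workflow import WorkflowMetadata

meta = WorkflowMetadata(on_failure=WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE)
restored = WorkflowMetadata.from_flyte_idl(meta.to_flyte_idl())
# from_flyte_idl maps a zero-valued on_failure back to FAIL_IMMEDIATELY, the proto default.
assert restored.on_failure == WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE
```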
It also contains metadata that tells the system how to execute the workflow (i.e. @@ -632,7 +645,7 @@ def to_flyte_idl(self): interface=self.interface.to_flyte_idl(), nodes=[n.to_flyte_idl() for n in self.nodes], outputs=[o.to_flyte_idl() for o in self.outputs], - failure_node=self.failure_node.to_flyte_idl() if self.failure_node is not None else None + failure_node=self.failure_node.to_flyte_idl() if self.failure_node is not None else None, ) @classmethod @@ -648,12 +661,11 @@ def from_flyte_idl(cls, pb2_object): interface=_interface.TypedInterface.from_flyte_idl(pb2_object.interface), nodes=[Node.from_flyte_idl(n) for n in pb2_object.nodes], outputs=[_Binding.from_flyte_idl(b) for b in pb2_object.outputs], - failure_node=Node.from_flyte_idl(pb2_object.failure_node) if pb2_object.HasField('failure_node') else None + failure_node=Node.from_flyte_idl(pb2_object.failure_node) if pb2_object.HasField("failure_node") else None, ) class Alias(_common.FlyteIdlEntity): - def __init__(self, var, alias): """ Links a variable to an alias. diff --git a/flytekit/models/dynamic_job.py b/flytekit/models/dynamic_job.py index a3ba95ae38..7018e3b664 100644 --- a/flytekit/models/dynamic_job.py +++ b/flytekit/models/dynamic_job.py @@ -2,7 +2,9 @@ from flyteidl.core import dynamic_job_pb2 as _dynamic_job -from flytekit.models import common as _common, task as _task, literals as _literals +from flytekit.models import common as _common +from flytekit.models import literals as _literals +from flytekit.models import task as _task from flytekit.models.core import workflow as _workflow @@ -79,7 +81,7 @@ def to_flyte_idl(self): nodes=[node.to_flyte_idl() for node in self.nodes] if self.nodes else None, min_successes=self.min_successes, outputs=[output.to_flyte_idl() for output in self.outputs], - subworkflows=[workflow.to_flyte_idl() for workflow in self.subworkflows] + subworkflows=[workflow.to_flyte_idl() for workflow in self.subworkflows], ) @classmethod @@ -92,7 +94,8 @@ def from_flyte_idl(cls, pb2_object): tasks=[_task.TaskTemplate.from_flyte_idl(task) for task in pb2_object.tasks] if pb2_object.tasks else None, nodes=[_workflow.Node.from_flyte_idl(n) for n in pb2_object.nodes], min_successes=pb2_object.min_successes, - outputs=[_literals.Binding.from_flyte_idl(output) for output in - pb2_object.outputs] if pb2_object.outputs else None, - subworkflows=[_workflow.WorkflowTemplate.from_flyte_idl(w) for w in pb2_object.subworkflows] + outputs=[_literals.Binding.from_flyte_idl(output) for output in pb2_object.outputs] + if pb2_object.outputs + else None, + subworkflows=[_workflow.WorkflowTemplate.from_flyte_idl(w) for w in pb2_object.subworkflows], ) diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index e6fed7abb6..2f07d37b28 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -3,13 +3,14 @@ import flyteidl.admin.execution_pb2 as _execution_pb2 import flyteidl.admin.node_execution_pb2 as _node_execution_pb2 import flyteidl.admin.task_execution_pb2 as _task_execution_pb2 +import pytz as _pytz from flytekit.models import common as _common_models -from flytekit.models.core import execution as _core_execution, identifier as _identifier -import pytz as _pytz +from flytekit.models.core import execution as _core_execution +from flytekit.models.core import identifier as _identifier -class ExecutionMetadata(_common_models.FlyteIdlEntity): +class ExecutionMetadata(_common_models.FlyteIdlEntity): class ExecutionMode(object): MANUAL = 0 SCHEDULED = 1 @@ -54,11 +55,7 @@ def
to_flyte_idl(self): """ :rtype: flyteidl.admin.execution_pb2.ExecutionMetadata """ - return _execution_pb2.ExecutionMetadata( - mode=self.mode, - principal=self.principal, - nesting=self.nesting - ) + return _execution_pb2.ExecutionMetadata(mode=self.mode, principal=self.principal, nesting=self.nesting) @classmethod def from_flyte_idl(cls, pb2_object): @@ -66,17 +63,20 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.execution_pb2.ExecutionMetadata pb2_object: :return: ExecutionMetadata """ - return cls( - mode=pb2_object.mode, - principal=pb2_object.principal, - nesting=pb2_object.nesting - ) + return cls(mode=pb2_object.mode, principal=pb2_object.principal, nesting=pb2_object.nesting,) class ExecutionSpec(_common_models.FlyteIdlEntity): - - def __init__(self, launch_plan, metadata, notifications=None, disable_all=None, labels=None, - annotations=None, auth_role=None): + def __init__( + self, + launch_plan, + metadata, + notifications=None, + disable_all=None, + labels=None, + annotations=None, + auth_role=None, + ): """ :param flytekit.models.core.identifier.Identifier launch_plan: Launch plan unique identifier to execute :param ExecutionMetadata metadata: The metadata to be associated with this execution @@ -177,7 +177,6 @@ def from_flyte_idl(cls, p): class LiteralMapBlob(_common_models.FlyteIdlEntity): - def __init__(self, values=None, uri=None): """ :param flytekit.models.literals.LiteralMap values: @@ -205,8 +204,7 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.execution_pb2.LiteralMapBlob """ return _execution_pb2.LiteralMapBlob( - values=self.values.to_flyte_idl() if self.values is not None else None, - uri=self.uri + values=self.values.to_flyte_idl() if self.values is not None else None, uri=self.uri, ) @classmethod @@ -218,14 +216,10 @@ def from_flyte_idl(cls, pb): values = None if pb.HasField("values"): values = LiteralMapBlob.from_flyte_idl(pb.values) - return cls( - values=values, - uri=pb.uri if pb.HasField("uri") else None - ) + return cls(values=values, uri=pb.uri if pb.HasField("uri") else None) class Execution(_common_models.FlyteIdlEntity): - def __init__(self, id, spec, closure): """ :param flytekit.models.core.identifier.WorkflowExecutionIdentifier id: @@ -263,9 +257,7 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.execution_pb2.Execution """ return _execution_pb2.Execution( - id=self.id.to_flyte_idl(), - closure=self.closure.to_flyte_idl(), - spec=self.spec.to_flyte_idl() + id=self.id.to_flyte_idl(), closure=self.closure.to_flyte_idl(), spec=self.spec.to_flyte_idl(), ) @classmethod @@ -277,12 +269,11 @@ def from_flyte_idl(cls, pb): return cls( id=_identifier.WorkflowExecutionIdentifier.from_flyte_idl(pb.id), closure=ExecutionClosure.from_flyte_idl(pb.closure), - spec=ExecutionSpec.from_flyte_idl(pb.spec) + spec=ExecutionSpec.from_flyte_idl(pb.spec), ) class ExecutionClosure(_common_models.FlyteIdlEntity): - def __init__(self, phase, started_at, error=None, outputs=None): """ :param int phase: From the flytekit.models.core.execution.WorkflowExecutionPhase enum @@ -331,7 +322,7 @@ def to_flyte_idl(self): obj = _execution_pb2.ExecutionClosure( phase=self.phase, error=self.error.to_flyte_idl() if self.error is not None else None, - outputs=self.outputs.to_flyte_idl() if self.outputs is not None else None + outputs=self.outputs.to_flyte_idl() if self.outputs is not None else None, ) obj.started_at.FromDatetime(self.started_at.astimezone(_pytz.UTC).replace(tzinfo=None)) return obj @@ -374,9 +365,7 @@ def to_flyte_idl(self): """ :rtype: 
flyteidl.admin.execution_pb2.NotificationList """ - return _execution_pb2.NotificationList( - notifications=[n.to_flyte_idl() for n in self.notifications] - ) + return _execution_pb2.NotificationList(notifications=[n.to_flyte_idl() for n in self.notifications]) @classmethod def from_flyte_idl(cls, pb2_object): @@ -384,9 +373,7 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.execution_pb2.NotificationList pb2_object: :rtype: NotificationList """ - return cls( - [_common_models.Notification.from_flyte_idl(p) for p in pb2_object.notifications] - ) + return cls([_common_models.Notification.from_flyte_idl(p) for p in pb2_object.notifications]) class _CommonDataResponse(_common_models.FlyteIdlEntity): @@ -435,8 +422,7 @@ def to_flyte_idl(self): :rtype: _execution_pb2.WorkflowExecutionGetDataResponse """ return _execution_pb2.WorkflowExecutionGetDataResponse( - inputs=self.inputs.to_flyte_idl(), - outputs=self.outputs.to_flyte_idl(), + inputs=self.inputs.to_flyte_idl(), outputs=self.outputs.to_flyte_idl(), ) @@ -457,8 +443,7 @@ def to_flyte_idl(self): :rtype: _task_execution_pb2.TaskExecutionGetDataResponse """ return _task_execution_pb2.TaskExecutionGetDataResponse( - inputs=self.inputs.to_flyte_idl(), - outputs=self.outputs.to_flyte_idl(), + inputs=self.inputs.to_flyte_idl(), outputs=self.outputs.to_flyte_idl(), ) @@ -479,6 +464,5 @@ def to_flyte_idl(self): :rtype: _node_execution_pb2.NodeExecutionGetDataResponse """ return _node_execution_pb2.NodeExecutionGetDataResponse( - inputs=self.inputs.to_flyte_idl(), - outputs=self.outputs.to_flyte_idl(), + inputs=self.inputs.to_flyte_idl(), outputs=self.outputs.to_flyte_idl(), ) diff --git a/flytekit/models/filters.py b/flytekit/models/filters.py index 7a7267252d..749671406d 100644 --- a/flytekit/models/filters.py +++ b/flytekit/models/filters.py @@ -1,9 +1,9 @@ from __future__ import absolute_import + from flytekit.models.common import FlyteIdlEntity as _FlyteIdlEntity class FilterList(_FlyteIdlEntity): - def __init__(self, filter_list): """ :param list[Filter] filter_list: List of filters to AND together @@ -51,21 +51,21 @@ def from_python_std(cls, string): :param Text string: :rtype: Filter """ - if string.startswith('eq('): + if string.startswith("eq("): return Equal._parse_from_string(string) - elif string.startswith('neq('): + elif string.startswith("neq("): return NotEqual._parse_from_string(string) - elif string.startswith('gt('): + elif string.startswith("gt("): return GreaterThan._parse_from_string(string) - elif string.startswith('gte('): + elif string.startswith("gte("): return GreaterThanOrEqual._parse_from_string(string) - elif string.startswith('lt('): + elif string.startswith("lt("): return LessThan._parse_from_string(string) - elif string.startswith('lte('): + elif string.startswith("lte("): return LessThanOrEqual._parse_from_string(string) - elif string.startswith('contains('): + elif string.startswith("contains("): return Contains._parse_from_string(string) - elif string.startswith('value_in('): + elif string.startswith("value_in("): return ValueIn._parse_from_string(string) else: raise ValueError("'{}' could not be parsed into a filter.".format(string)) @@ -76,8 +76,8 @@ def _parse_from_string(cls, string): :param Text string: :rtype: Filter """ - stripped = string[len(cls._comparator) + 1:] - if stripped[-1] != ')': + stripped = string[len(cls._comparator) + 1 :] + if stripped[-1] != ")": raise ValueError("Filter could not be parsed because {} did not end with a ')'".format(string)) split = stripped[:-1].split(",") 
if len(split) != 2: @@ -92,31 +92,30 @@ def _parse_value(cls, value): class Equal(Filter): - _comparator = 'eq' + _comparator = "eq" class NotEqual(Filter): - _comparator = 'neq' + _comparator = "neq" class GreaterThan(Filter): - _comparator = 'gt' + _comparator = "gt" class GreaterThanOrEqual(Filter): - _comparator = 'gte' + _comparator = "gte" class LessThan(Filter): - _comparator = 'lt' + _comparator = "lt" class LessThanOrEqual(Filter): - _comparator = 'lte' + _comparator = "lte" class SetFilter(Filter): - def __init__(self, key, values): """ :param Text key: The name of the field to compare against @@ -126,12 +125,12 @@ def __init__(self, key, values): @classmethod def _parse_value(cls, value): - return value.split(';') + return value.split(";") class Contains(SetFilter): - _comparator = 'contains' + _comparator = "contains" class ValueIn(SetFilter): - _comparator = 'value_in' + _comparator = "value_in" diff --git a/flytekit/models/interface.py b/flytekit/models/interface.py index cc96071449..3e2cab71c1 100644 --- a/flytekit/models/interface.py +++ b/flytekit/models/interface.py @@ -1,12 +1,14 @@ from __future__ import absolute_import -from flytekit.models import common as _common, types as _types, literals as _literals -from flyteidl.core import interface_pb2 as _interface_pb2 import six as _six +from flyteidl.core import interface_pb2 as _interface_pb2 +from flytekit.models import common as _common +from flytekit.models import literals as _literals +from flytekit.models import types as _types -class Variable(_common.FlyteIdlEntity): +class Variable(_common.FlyteIdlEntity): def __init__(self, type, description): """ :param flytekit.models.types.LiteralType type: This describes the type of value that must be provided to @@ -37,10 +39,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.interface_pb2.Variable """ - return _interface_pb2.Variable( - type=self.type.to_flyte_idl(), - description=self.description - ) + return _interface_pb2.Variable(type=self.type.to_flyte_idl(), description=self.description) @classmethod def from_flyte_idl(cls, variable_proto): @@ -48,14 +47,10 @@ def from_flyte_idl(cls, variable_proto): :param flyteidl.core.interface_pb2.Variable variable_proto: :rtype: Variable """ - return cls( - type=_types.LiteralType.from_flyte_idl(variable_proto.type), - description=variable_proto.description - ) + return cls(type=_types.LiteralType.from_flyte_idl(variable_proto.type), description=variable_proto.description,) class VariableMap(_common.FlyteIdlEntity): - def __init__(self, variables): """ A map of Variables @@ -75,9 +70,7 @@ def to_flyte_idl(self): """ :rtype: dict[Text, Variable] """ - return _interface_pb2.VariableMap( - variables={k: v.to_flyte_idl() for k, v in _six.iteritems(self.variables)} - ) + return _interface_pb2.VariableMap(variables={k: v.to_flyte_idl() for k, v in _six.iteritems(self.variables)}) @classmethod def from_flyte_idl(cls, pb2_object): @@ -85,12 +78,10 @@ def from_flyte_idl(cls, pb2_object): :param dict[Text, Variable] pb2_object: :rtype: VariableMap """ - return cls({k: Variable.from_flyte_idl(v) - for k, v in _six.iteritems(pb2_object.variables)}) + return cls({k: Variable.from_flyte_idl(v) for k, v in _six.iteritems(pb2_object.variables)}) class TypedInterface(_common.FlyteIdlEntity): - def __init__(self, inputs, outputs): """ Please note that this model is slightly incorrect, but is more user-friendly. 
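The filter classes above parse a small string grammar via `Filter.from_python_std`; the field names below are purely illustrative:

```python
from flytekit.models.filters import Filter

eq_filter = Filter.from_python_std("eq(workflow.name,my_wf)")
set_filter = Filter.from_python_std("value_in(phase,SUCCEEDED;FAILED)")  # SetFilter splits values on ';'

try:
    Filter.from_python_std("between(a,b)")
except ValueError:
    pass  # unknown comparators are rejected with a descriptive message
```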
The underlying inputs and @@ -121,12 +112,10 @@ def to_flyte_idl(self): :rtype: flyteidl.core.interface_pb2.TypedInterface """ return _interface_pb2.TypedInterface( - inputs=_interface_pb2.VariableMap( - variables={k: v.to_flyte_idl() for k, v in _six.iteritems(self.inputs)} - ), + inputs=_interface_pb2.VariableMap(variables={k: v.to_flyte_idl() for k, v in _six.iteritems(self.inputs)}), outputs=_interface_pb2.VariableMap( variables={k: v.to_flyte_idl() for k, v in _six.iteritems(self.outputs)} - ) + ), ) @classmethod @@ -137,12 +126,11 @@ def from_flyte_idl(cls, proto): """ return cls( inputs={k: Variable.from_flyte_idl(v) for k, v in _six.iteritems(proto.inputs.variables)}, - outputs={k: Variable.from_flyte_idl(v) for k, v in _six.iteritems(proto.outputs.variables)} + outputs={k: Variable.from_flyte_idl(v) for k, v in _six.iteritems(proto.outputs.variables)}, ) class Parameter(_common.FlyteIdlEntity): - def __init__(self, var, default=None, required=None): """ Declares an input parameter. A parameter is used as input to a launch plan and has @@ -206,12 +194,11 @@ def from_flyte_idl(cls, pb2_object): return cls( Variable.from_flyte_idl(pb2_object.var), _literals.Literal.from_flyte_idl(pb2_object.default) if pb2_object.HasField("default") else None, - pb2_object.required if pb2_object.HasField("required") else None + pb2_object.required if pb2_object.HasField("required") else None, ) class ParameterMap(_common.FlyteIdlEntity): - def __init__(self, parameters): """ A map of Parameters diff --git a/flytekit/models/launch_plan.py b/flytekit/models/launch_plan.py index b814d1aa25..373ffc0351 100644 --- a/flytekit/models/launch_plan.py +++ b/flytekit/models/launch_plan.py @@ -2,12 +2,14 @@ from flyteidl.admin import launch_plan_pb2 as _launch_plan -from flytekit.models import common as _common, interface as _interface, literals as _literals, schedule as _schedule +from flytekit.models import common as _common +from flytekit.models import interface as _interface +from flytekit.models import literals as _literals +from flytekit.models import schedule as _schedule from flytekit.models.core import identifier as _identifier class LaunchPlanMetadata(_common.FlyteIdlEntity): - def __init__(self, schedule, notifications): """ @@ -41,7 +43,7 @@ def to_flyte_idl(self): """ return _launch_plan.LaunchPlanMetadata( schedule=self.schedule.to_flyte_idl() if self.schedule is not None else None, - notifications=[n.to_flyte_idl() for n in self.notifications] + notifications=[n.to_flyte_idl() for n in self.notifications], ) @classmethod @@ -50,9 +52,12 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.launch_plan_pb2.LaunchPlanMetadata pb2_object: :rtype: LaunchPlanMetadata """ - return cls(schedule=_schedule.Schedule.from_flyte_idl(pb2_object.schedule) if pb2_object.HasField("schedule") - else None, - notifications=[_common.Notification.from_flyte_idl(n) for n in pb2_object.notifications]) + return cls( + schedule=_schedule.Schedule.from_flyte_idl(pb2_object.schedule) + if pb2_object.HasField("schedule") + else None, + notifications=[_common.Notification.from_flyte_idl(n) for n in pb2_object.notifications], + ) class Auth(_common.FlyteIdlEntity): @@ -102,15 +107,24 @@ def from_flyte_idl(cls, pb2_object): """ return cls( assumable_iam_role=pb2_object.assumable_iam_role if pb2_object.HasField("assumable_iam_role") else None, - kubernetes_service_account=pb2_object.kubernetes_service_account if - pb2_object.HasField("kubernetes_service_account") else None, + 
kubernetes_service_account=pb2_object.kubernetes_service_account + if pb2_object.HasField("kubernetes_service_account") + else None, ) class LaunchPlanSpec(_common.FlyteIdlEntity): - - def __init__(self, workflow_id, entity_metadata, default_inputs, fixed_inputs, labels, annotations, auth_role, - raw_output_data_config): + def __init__( + self, + workflow_id, + entity_metadata, + default_inputs, + fixed_inputs, + labels, + annotations, + auth_role, + raw_output_data_config, + ): """ The spec for a Launch Plan. @@ -250,7 +264,6 @@ def enum_to_string(cls, val): class LaunchPlanClosure(_common.FlyteIdlEntity): - def __init__(self, state, expected_inputs, expected_outputs): """ :param LaunchPlanState state: Indicate the Launch plan phase @@ -307,13 +320,7 @@ def from_flyte_idl(cls, pb2_object): class LaunchPlan(_common.FlyteIdlEntity): - - def __init__( - self, - id, - spec, - closure - ): + def __init__(self, id, spec, closure): """ :param flytekit.models.core.identifier.Identifier id: :param LaunchPlanSpec spec: @@ -349,9 +356,7 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.launch_plan_pb2.LaunchPlan """ return _launch_plan.LaunchPlan( - id=self.id.to_flyte_idl(), - spec=self.spec.to_flyte_idl(), - closure=self.closure.to_flyte_idl() + id=self.id.to_flyte_idl(), spec=self.spec.to_flyte_idl(), closure=self.closure.to_flyte_idl(), ) @classmethod @@ -363,5 +368,5 @@ def from_flyte_idl(cls, pb2_object): return cls( id=_identifier.Identifier.from_flyte_idl(pb2_object.id), spec=LaunchPlanSpec.from_flyte_idl(pb2_object.spec), - closure=LaunchPlanClosure.from_flyte_idl(pb2_object.closure) + closure=LaunchPlanClosure.from_flyte_idl(pb2_object.closure), ) diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index db91bbce4d..2cf8f4008a 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -1,18 +1,19 @@ from __future__ import absolute_import +from datetime import datetime as _datetime + import pytz as _pytz import six as _six -from datetime import datetime as _datetime from flyteidl.core import literals_pb2 as _literals_pb2 from flytekit.common.exceptions import user as _user_exceptions from flytekit.models import common as _common -from flytekit.models.types import SchemaType as _SchemaType, OutputReference as _OutputReference from flytekit.models.core import types as _core_types +from flytekit.models.types import OutputReference as _OutputReference +from flytekit.models.types import SchemaType as _SchemaType class RetryStrategy(_common.FlyteIdlEntity): - def __init__(self, retries): """ :param int retries: Number of retries to attempt on recoverable failures. If retries is 0, then @@ -40,14 +41,13 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.literals_pb2.RetryStrategy pb2_object: :rtype: RetryStrategy """ - return cls( - retries=pb2_object.retries - ) + return cls(retries=pb2_object.retries) class Primitive(_common.FlyteIdlEntity): - - def __init__(self, integer=None, float_value=None, string_value=None, boolean=None, datetime=None, duration=None): + def __init__( + self, integer=None, float_value=None, string_value=None, boolean=None, datetime=None, duration=None, + ): """ This object proxies the primitives supported by the Flyte IDL system. Only one value can be set. :param int integer: [Optional] @@ -119,7 +119,14 @@ def value(self): This returns whichever field is set. 
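As the `Primitive` changes above show, `to_flyte_idl()` normalizes datetimes to UTC and strips the timezone (protobuf timestamps are naive), while `from_flyte_idl()` re-attaches `pytz.UTC`. A round-trip sketch with a timezone-aware input:

```python
import datetime

import pytz

from flytekit.models.literals import Primitive

p = Primitive(datetime=datetime.datetime(2020, 8, 1, 12, 0, tzinfo=pytz.UTC))
restored = Primitive.from_flyte_idl(p.to_flyte_idl())
assert restored.value == p.value  # .value returns whichever single field is set
```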
:rtype: T """ - for value in [self.integer, self.float_value, self.string_value, self.boolean, self.datetime, self.duration]: + for value in [ + self.integer, + self.float_value, + self.string_value, + self.boolean, + self.datetime, + self.duration, + ]: if value is not None: return value @@ -128,10 +135,7 @@ def to_flyte_idl(self): :rtype: flyteidl.core.literals_pb2.Primitive """ primitive = _literals_pb2.Primitive( - integer=self.integer, - float_value=self.float_value, - string_value=self.string_value, - boolean=self.boolean + integer=self.integer, float_value=self.float_value, string_value=self.string_value, boolean=self.boolean, ) if self.datetime is not None: # Convert to UTC and remove timezone so protobuf behaves. @@ -152,12 +156,11 @@ def from_flyte_idl(cls, proto): string_value=proto.string_value if proto.HasField("string_value") else None, boolean=proto.boolean if proto.HasField("boolean") else None, datetime=proto.datetime.ToDatetime().replace(tzinfo=_pytz.UTC) if proto.HasField("datetime") else None, - duration=proto.duration.ToTimedelta() if proto.HasField("duration") else None + duration=proto.duration.ToTimedelta() if proto.HasField("duration") else None, ) class Binary(_common.FlyteIdlEntity): - def __init__(self, value, tag): """ :param bytes value: @@ -192,15 +195,13 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.literals_pb2.Binary pb2_object: :rtype: Binary """ - return cls( - value=pb2_object.value, - tag=pb2_object.tag - ) + return cls(value=pb2_object.value, tag=pb2_object.tag) class Scalar(_common.FlyteIdlEntity): - - def __init__(self, primitive=None, blob=None, binary=None, schema=None, none_type=None, error=None, generic=None): + def __init__( + self, primitive=None, blob=None, binary=None, schema=None, none_type=None, error=None, generic=None, + ): """ Scalar wrapper around Flyte types. Only one can be specified. @@ -311,7 +312,6 @@ def from_flyte_idl(cls, pb2_object): class BlobMetadata(_common.FlyteIdlEntity): - def __init__(self, type): """ :param flytekit.models.core.types.BlobType type: The type of the underlying blob @@ -341,7 +341,6 @@ def from_flyte_idl(cls, proto): class Blob(_common.FlyteIdlEntity): - def __init__(self, metadata, uri): """ :param BlobMetadata metadata: @@ -368,10 +367,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.Blob """ - return _literals_pb2.Blob( - metadata=self.metadata.to_flyte_idl(), - uri=self.uri - ) + return _literals_pb2.Blob(metadata=self.metadata.to_flyte_idl(), uri=self.uri) @classmethod def from_flyte_idl(cls, proto): @@ -379,14 +375,10 @@ def from_flyte_idl(cls, proto): :param flyteidl.core.literals_pb2.Blob proto: :rtype: Blob """ - return cls( - metadata=BlobMetadata.from_flyte_idl(proto.metadata), - uri=proto.uri - ) + return cls(metadata=BlobMetadata.from_flyte_idl(proto.metadata), uri=proto.uri) class Void(_common.FlyteIdlEntity): - def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.Void @@ -403,7 +395,6 @@ def from_flyte_idl(cls, proto): class BindingDataMap(_common.FlyteIdlEntity): - def __init__(self, bindings): """ A map of BindingData items. 
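A sketch assembling the blob-flavored scalar types touched above, assuming the usual `BlobDimensionality` constants on `BlobType` (the format string and URI are placeholders):

```python
from flytekit.models.core.types import BlobType
from flytekit.models.literals import Blob, BlobMetadata

blob_type = BlobType(format="csv", dimensionality=BlobType.BlobDimensionality.SINGLE)
blob = Blob(metadata=BlobMetadata(type=blob_type), uri="s3://my-bucket/data.csv")
assert Blob.from_flyte_idl(blob.to_flyte_idl()) == blob
```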
Can be a recursive structure @@ -424,9 +415,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.BindingDataMap """ - return _literals_pb2.BindingDataMap( - bindings={k: v.to_flyte_idl() for (k, v) in _six.iteritems(self.bindings)} - ) + return _literals_pb2.BindingDataMap(bindings={k: v.to_flyte_idl() for (k, v) in _six.iteritems(self.bindings)}) @classmethod def from_flyte_idl(cls, pb2_object): @@ -439,7 +428,6 @@ def from_flyte_idl(cls, pb2_object): class BindingDataCollection(_common.FlyteIdlEntity): - def __init__(self, bindings): """ A list of BindingData items. @@ -471,7 +459,6 @@ def from_flyte_idl(cls, pb2_object): class BindingData(_common.FlyteIdlEntity): - def __init__(self, scalar=None, collection=None, promise=None, map=None): """ Specifies either a simple value or a reference to another output. Only one of the input arguments may be @@ -536,7 +523,7 @@ def to_flyte_idl(self): scalar=self.scalar.to_flyte_idl() if self.scalar is not None else None, collection=self.collection.to_flyte_idl() if self.collection is not None else None, promise=self.promise.to_flyte_idl() if self.promise is not None else None, - map=self.map.to_flyte_idl() if self.map is not None else None + map=self.map.to_flyte_idl() if self.map is not None else None, ) @classmethod @@ -548,9 +535,10 @@ def from_flyte_idl(cls, pb2_object): return cls( scalar=Scalar.from_flyte_idl(pb2_object.scalar) if pb2_object.HasField("scalar") else None, collection=BindingDataCollection.from_flyte_idl(pb2_object.collection) - if pb2_object.HasField("collection") else None, + if pb2_object.HasField("collection") + else None, promise=_OutputReference.from_flyte_idl(pb2_object.promise) if pb2_object.HasField("promise") else None, - map=BindingDataMap.from_flyte_idl(pb2_object.map) if pb2_object.HasField("map") else None + map=BindingDataMap.from_flyte_idl(pb2_object.map) if pb2_object.HasField("map") else None, ) def to_literal_model(self): @@ -559,19 +547,22 @@ def to_literal_model(self): :rtype: Literal """ if self.promise: - raise _user_exceptions.FlyteValueException(self.promise, "Cannot convert BindingData to a Literal because " - "it has a promise.") + raise _user_exceptions.FlyteValueException( + self.promise, "Cannot convert BindingData to a Literal because " "it has a promise.", + ) elif self.scalar: return Literal(scalar=self.scalar) elif self.collection: - return Literal(collection=LiteralCollection( - literals=[binding.to_literal_model() for binding in self.collection.bindings])) + return Literal( + collection=LiteralCollection( + literals=[binding.to_literal_model() for binding in self.collection.bindings] + ) + ) elif self.map: return Literal(map=LiteralMap(literals={k: binding.to_literal_model() for k, binding in self.map.bindings})) class Binding(_common.FlyteIdlEntity): - def __init__(self, var, binding): """ An input/output binding of a variable to either static value or a node output. @@ -602,10 +593,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.Binding """ - return _literals_pb2.Binding( - var=self.var, - binding=self.binding.to_flyte_idl() - ) + return _literals_pb2.Binding(var=self.var, binding=self.binding.to_flyte_idl()) @classmethod def from_flyte_idl(cls, pb2_object): @@ -617,7 +605,6 @@ def from_flyte_idl(cls, pb2_object): class Schema(_common.FlyteIdlEntity): - def __init__(self, uri, type): """ A strongly typed schema that defines the interface of data retrieved from the underlying storage medium. 
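`BindingData.to_literal_model()` above converts static bindings directly and raises for promise-backed ones; a minimal sketch:

```python
from flytekit.models.literals import BindingData, Primitive, Scalar

bd = BindingData(scalar=Scalar(primitive=Primitive(integer=7)))
literal = bd.to_literal_model()  # yields Literal(scalar=...)
assert literal.scalar.primitive.integer == 7
# A promise-backed BindingData would raise FlyteValueException here instead.
```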
@@ -658,7 +645,6 @@ def from_flyte_idl(cls, pb2_object): class LiteralCollection(_common.FlyteIdlEntity): - def __init__(self, literals): """ :param list[Literal] literals: underlying list of literals in this collection. @@ -676,9 +662,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.LiteralCollection """ - return _literals_pb2.LiteralCollection( - literals=[l.to_flyte_idl() for l in self.literals] - ) + return _literals_pb2.LiteralCollection(literals=[l.to_flyte_idl() for l in self.literals]) @classmethod def from_flyte_idl(cls, pb2_object): @@ -690,7 +674,6 @@ def from_flyte_idl(cls, pb2_object): class LiteralMap(_common.FlyteIdlEntity): - def __init__(self, literals): """ :param dict[Text, Literal] literals: A dictionary mapping Text key names to Literal objects. @@ -709,9 +692,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.LiteralMap """ - return _literals_pb2.LiteralMap( - literals={k: v.to_flyte_idl() for k, v in _six.iteritems(self.literals)} - ) + return _literals_pb2.LiteralMap(literals={k: v.to_flyte_idl() for k, v in _six.iteritems(self.literals)}) @classmethod def from_flyte_idl(cls, pb2_object): @@ -723,7 +704,6 @@ def from_flyte_idl(cls, pb2_object): class Literal(_common.FlyteIdlEntity): - def __init__(self, scalar=None, collection=None, map=None): """ :param Scalar scalar: @@ -773,7 +753,7 @@ def to_flyte_idl(self): return _literals_pb2.Literal( scalar=self.scalar.to_flyte_idl() if self.scalar is not None else None, collection=self.collection.to_flyte_idl() if self.collection is not None else None, - map=self.map.to_flyte_idl() if self.map is not None else None + map=self.map.to_flyte_idl() if self.map is not None else None, ) @classmethod @@ -789,5 +769,5 @@ def from_flyte_idl(cls, pb2_object): return cls( scalar=Scalar.from_flyte_idl(pb2_object.scalar) if pb2_object.HasField("scalar") else None, collection=collection, - map=LiteralMap.from_flyte_idl(pb2_object.map) if pb2_object.HasField("map") else None + map=LiteralMap.from_flyte_idl(pb2_object.map) if pb2_object.HasField("map") else None, ) diff --git a/flytekit/models/matchable_resource.py b/flytekit/models/matchable_resource.py index 46171b0c19..16fd502a59 100644 --- a/flytekit/models/matchable_resource.py +++ b/flytekit/models/matchable_resource.py @@ -1,9 +1,9 @@ from flyteidl.admin import matchable_resource_pb2 as _matchable_resource + from flytekit.models import common as _common class ClusterResourceAttributes(_common.FlyteIdlEntity): - def __init__(self, attributes): """ Custom resource attributes which will be applied in cluster resource creation (e.g. quotas). @@ -26,9 +26,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.matchable_resource_pb2.ClusterResourceAttributes """ - return _matchable_resource.ClusterResourceAttributes( - attributes=self.attributes, - ) + return _matchable_resource.ClusterResourceAttributes(attributes=self.attributes,) @classmethod def from_flyte_idl(cls, pb2_object): @@ -36,13 +34,10 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.matchable_resource_pb2.ClusterResourceAttributes pb2_object: :rtype: ClusterResourceAttributes """ - return cls( - attributes=pb2_object.attributes, - ) + return cls(attributes=pb2_object.attributes,) class ExecutionQueueAttributes(_common.FlyteIdlEntity): - def __init__(self, tags): """ Tags used for assigning execution queues for tasks matching a project, domain and optionally, workflow. 
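The matchable-resource wrappers above take plain string mappings; a sketch for `ClusterResourceAttributes` (the quota keys and values are invented for illustration):

```python
from flytekit.models.matchable_resource import ClusterResourceAttributes

attrs = ClusterResourceAttributes(attributes={"projectQuotaCpu": "8", "projectQuotaMemory": "16Gi"})
assert ClusterResourceAttributes.from_flyte_idl(attrs.to_flyte_idl()) == attrs
```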
@@ -62,9 +57,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.matchable_resource_pb2.ExecutionQueueAttributes """ - return _matchable_resource.ExecutionQueueAttributes( - tags=self.tags, - ) + return _matchable_resource.ExecutionQueueAttributes(tags=self.tags,) @classmethod def from_flyte_idl(cls, pb2_object): @@ -72,13 +65,10 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.matchable_resource_pb2.ExecutionQueueAttributes pb2_object: :rtype: ExecutionQueueAttributes """ - return cls( - tags=pb2_object.tags, - ) + return cls(tags=pb2_object.tags,) class ExecutionClusterLabel(_common.FlyteIdlEntity): - def __init__(self, value): """ Label value to determine where the execution will be run @@ -98,9 +88,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.matchable_resource_pb2.ExecutionClusterLabel """ - return _matchable_resource.ExecutionClusterLabel( - value=self.value, - ) + return _matchable_resource.ExecutionClusterLabel(value=self.value,) @classmethod def from_flyte_idl(cls, pb2_object): @@ -108,13 +96,13 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.matchable_resource_pb2.ExecutionClusterLabel pb2_object: :rtype: ExecutionClusterLabel """ - return cls( - value=pb2_object.value, - ) + return cls(value=pb2_object.value,) class MatchingAttributes(_common.FlyteIdlEntity): - def __init__(self, cluster_resource_attributes=None, execution_queue_attributes=None, execution_cluster_label=None): + def __init__( + self, cluster_resource_attributes=None, execution_queue_attributes=None, execution_cluster_label=None, + ): """ At most one target from cluster_resource_attributes, execution_queue_attributes or execution_cluster_label can be set. @@ -161,11 +149,14 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.matchable_resource_pb2.MatchingAttributes """ return _matchable_resource.MatchingAttributes( - cluster_resource_attributes=self.cluster_resource_attributes.to_flyte_idl() if - self.cluster_resource_attributes else None, - execution_queue_attributes=self.execution_queue_attributes.to_flyte_idl() if self.execution_queue_attributes + cluster_resource_attributes=self.cluster_resource_attributes.to_flyte_idl() + if self.cluster_resource_attributes else None, - execution_cluster_label=self.execution_cluster_label.to_flyte_idl() if self.execution_cluster_label + execution_queue_attributes=self.execution_queue_attributes.to_flyte_idl() + if self.execution_queue_attributes + else None, + execution_cluster_label=self.execution_cluster_label.to_flyte_idl() + if self.execution_cluster_label else None, ) @@ -176,10 +167,13 @@ def from_flyte_idl(cls, pb2_object): :rtype: MatchingAttributes """ return cls( - cluster_resource_attributes=ClusterResourceAttributes.from_flyte_idl( - pb2_object.cluster_resource_attributes) if pb2_object.HasField("cluster_resource_attributes") else None, - execution_queue_attributes=ExecutionQueueAttributes.from_flyte_idl(pb2_object.execution_queue_attributes) if - pb2_object.HasField("execution_queue_attributes") else None, - execution_cluster_label=ExecutionClusterLabel.from_flyte_idl(pb2_object.execution_cluster_label) if - pb2_object.HasField("execution_cluster_label") else None, + cluster_resource_attributes=ClusterResourceAttributes.from_flyte_idl(pb2_object.cluster_resource_attributes) + if pb2_object.HasField("cluster_resource_attributes") + else None, + execution_queue_attributes=ExecutionQueueAttributes.from_flyte_idl(pb2_object.execution_queue_attributes) + if pb2_object.HasField("execution_queue_attributes") + else None, + 
execution_cluster_label=ExecutionClusterLabel.from_flyte_idl(pb2_object.execution_cluster_label) + if pb2_object.HasField("execution_cluster_label") + else None, ) diff --git a/flytekit/models/named_entity.py b/flytekit/models/named_entity.py index 63dd598d98..80d70aa35c 100644 --- a/flytekit/models/named_entity.py +++ b/flytekit/models/named_entity.py @@ -57,11 +57,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.common_pb2.NamedEntityIdentifier """ - return _common.NamedEntityIdentifier( - project=self.project, - domain=self.domain, - name=self.name, - ) + return _common.NamedEntityIdentifier(project=self.project, domain=self.domain, name=self.name,) @classmethod def from_flyte_idl(cls, p): @@ -69,11 +65,7 @@ def from_flyte_idl(cls, p): :param flyteidl.core.common_pb2.NamedEntityIdentifier p: :rtype: Identifier """ - return cls( - project=p.project, - domain=p.domain, - name=p.name, - ) + return cls(project=p.project, domain=p.domain, name=p.name,) class NamedEntityMetadata(_common_models.FlyteIdlEntity): @@ -105,10 +97,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.common_pb2.NamedEntityMetadata """ - return _common.NamedEntityMetadata( - description=self.description, - state=self.state, - ) + return _common.NamedEntityMetadata(description=self.description, state=self.state,) @classmethod def from_flyte_idl(cls, p): @@ -116,7 +105,4 @@ def from_flyte_idl(cls, p): :param flyteidl.core.common_pb2.NamedEntityMetadata p: :rtype: Identifier """ - return cls( - description=p.description, - state=p.state, - ) + return cls(description=p.description, state=p.state,) diff --git a/flytekit/models/node_execution.py b/flytekit/models/node_execution.py index 16480808ac..f47d71b45d 100644 --- a/flytekit/models/node_execution.py +++ b/flytekit/models/node_execution.py @@ -1,20 +1,15 @@ from __future__ import absolute_import -from flytekit.models import common as _common_models -from flytekit.models.core import execution as _core_execution, identifier as _identifier + import flyteidl.admin.node_execution_pb2 as _node_execution_pb2 import pytz as _pytz +from flytekit.models import common as _common_models +from flytekit.models.core import execution as _core_execution +from flytekit.models.core import identifier as _identifier + class NodeExecutionClosure(_common_models.FlyteIdlEntity): - - def __init__( - self, - phase, - started_at, - duration, - output_uri=None, - error=None - ): + def __init__(self, phase, started_at, duration, output_uri=None, error=None): """ :param int phase: :param datetime.datetime started_at: @@ -70,7 +65,7 @@ def to_flyte_idl(self): obj = _node_execution_pb2.NodeExecutionClosure( phase=self.phase, output_uri=self.output_uri, - error=self.error.to_flyte_idl() if self.error is not None else None + error=self.error.to_flyte_idl() if self.error is not None else None, ) obj.started_at.FromDatetime(self.started_at.astimezone(_pytz.UTC).replace(tzinfo=None)) obj.duration.FromTimedelta(self.duration) @@ -87,18 +82,12 @@ def from_flyte_idl(cls, p): output_uri=p.output_uri if p.HasField("output_uri") else None, error=_core_execution.ExecutionError.from_flyte_idl(p.error) if p.HasField("error") else None, started_at=p.started_at.ToDatetime().replace(tzinfo=_pytz.UTC), - duration=p.duration.ToTimedelta() + duration=p.duration.ToTimedelta(), ) class NodeExecution(_common_models.FlyteIdlEntity): - - def __init__( - self, - id, - input_uri, - closure - ): + def __init__(self, id, input_uri, closure): """ :param flytekit.models.core.identifier.NodeExecutionIdentifier id: :param 
Text input_uri: @@ -134,9 +123,7 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.node_execution_pb2.NodeExecution """ return _node_execution_pb2.NodeExecution( - id=self.id.to_flyte_idl(), - input_uri=self.input_uri, - closure=self.closure.to_flyte_idl() + id=self.id.to_flyte_idl(), input_uri=self.input_uri, closure=self.closure.to_flyte_idl(), ) @classmethod @@ -148,5 +135,5 @@ def from_flyte_idl(cls, p): return cls( id=_identifier.NodeExecutionIdentifier.from_flyte_idl(p.id), input_uri=p.input_uri, - closure=NodeExecutionClosure.from_flyte_idl(p.closure) + closure=NodeExecutionClosure.from_flyte_idl(p.closure), ) diff --git a/flytekit/models/presto.py b/flytekit/models/presto.py index 9ccb184b99..13edc42b9f 100644 --- a/flytekit/models/presto.py +++ b/flytekit/models/presto.py @@ -54,13 +54,10 @@ def statement(self): def to_flyte_idl(self): """ - :rtype: _presto.PrestoQuery + :rtype: _presto.PrestoQuery """ return _presto.PrestoQuery( - routing_group=self._routing_group, - catalog=self._catalog, - schema=self._schema, - statement=self._statement + routing_group=self._routing_group, catalog=self._catalog, schema=self._schema, statement=self._statement, ) @classmethod @@ -73,5 +70,5 @@ def from_flyte_idl(cls, pb2_object): routing_group=pb2_object.routing_group, catalog=pb2_object.catalog, schema=pb2_object.schema, - statement=pb2_object.statement + statement=pb2_object.statement, ) diff --git a/flytekit/models/project.py b/flytekit/models/project.py index 03a5f808c1..80d80e35b8 100644 --- a/flytekit/models/project.py +++ b/flytekit/models/project.py @@ -6,7 +6,6 @@ class Project(_common.FlyteIdlEntity): - def __init__(self, id, name, description): """ A project represents a logical grouping used to organize entities (tasks, workflows, executions) in the Flyte @@ -48,11 +47,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.project_pb2.Project """ - return _project_pb2.Project( - id=self.id, - name=self.name, - description=self.description, - ) + return _project_pb2.Project(id=self.id, name=self.name, description=self.description,) @classmethod def from_flyte_idl(cls, pb2_object): @@ -60,8 +55,4 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.project_pb2.Project pb2_object: :rtype: Project """ - return cls( - id=pb2_object.id, - name=pb2_object.name, - description=pb2_object.description, - ) + return cls(id=pb2_object.id, name=pb2_object.name, description=pb2_object.description,) diff --git a/flytekit/models/qubole.py b/flytekit/models/qubole.py index ea6151e9bc..1e421ec89e 100644 --- a/flytekit/models/qubole.py +++ b/flytekit/models/qubole.py @@ -45,11 +45,7 @@ def to_flyte_idl(self): """ :rtype: _qubole.HiveQuery """ - return _qubole.HiveQuery( - query=self.query, - timeout_sec=self.timeout_sec, - retryCount=self.retry_count - ) + return _qubole.HiveQuery(query=self.query, timeout_sec=self.timeout_sec, retryCount=self.retry_count) @classmethod def from_flyte_idl(cls, pb2_object): @@ -57,11 +53,8 @@ def from_flyte_idl(cls, pb2_object): :param _qubole.HiveQuery pb2_object: :return: HiveQuery """ - return cls( - query=pb2_object.query, - timeout_sec=pb2_object.timeout_sec, - retry_count=pb2_object.retryCount - ) + return cls(query=pb2_object.query, timeout_sec=pb2_object.timeout_sec, retry_count=pb2_object.retryCount,) + class HiveQueryCollection(_common.FlyteIdlEntity): def __init__(self, queries): @@ -93,13 +86,10 @@ def from_flyte_idl(cls, pb2_object): :param _qubole.HiveQuery pb2_object: :rtype: HiveQueryCollection """ - return cls( - 
queries=[HiveQuery.from_flyte_idl(query) for query in pb2_object.queries] - ) + return cls(queries=[HiveQuery.from_flyte_idl(query) for query in pb2_object.queries]) class QuboleHiveJob(_common.FlyteIdlEntity): - def __init__(self, query, cluster_label, tags, query_collection=None): """ Initializes a HiveJob. @@ -130,7 +120,6 @@ def query(self): """ return self._query - @property def cluster_label(self): """ @@ -155,7 +144,7 @@ def to_flyte_idl(self): query_collection=self._query_collection.to_flyte_idl() if self._query_collection else None, query=self._query.to_flyte_idl() if self._query else None, cluster_label=self._cluster_label, - tags=self._tags + tags=self._tags, ) @classmethod @@ -165,8 +154,9 @@ def from_flyte_idl(cls, p): :rtype: QuboleHiveJob """ return cls( - query_collection=HiveQueryCollection.from_flyte_idl(p.query_collection) if p.HasField( - "query_collection") else None, + query_collection=HiveQueryCollection.from_flyte_idl(p.query_collection) + if p.HasField("query_collection") + else None, query=HiveQuery.from_flyte_idl(p.query) if p.HasField("query") else None, cluster_label=p.cluster_label, tags=p.tags, diff --git a/flytekit/models/sagemaker/hpo_job.py b/flytekit/models/sagemaker/hpo_job.py index 6ecaf330a5..6fb0b1ebd6 100644 --- a/flytekit/models/sagemaker/hpo_job.py +++ b/flytekit/models/sagemaker/hpo_job.py @@ -3,7 +3,8 @@ from flyteidl.plugins.sagemaker import hyperparameter_tuning_job_pb2 as _pb2_hpo_job from flytekit.models import common as _common -from flytekit.models.sagemaker import parameter_ranges as _parameter_ranges_models, training_job as _training_job +from flytekit.models.sagemaker import parameter_ranges as _parameter_ranges_models +from flytekit.models.sagemaker import training_job as _training_job class HyperparameterTuningObjectiveType(object): @@ -18,10 +19,9 @@ class HyperparameterTuningObjective(_common.FlyteIdlEntity): https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-metrics.html """ + def __init__( - self, - objective_type: int, - metric_name: str, + self, objective_type: int, metric_name: str, ): self._objective_type = objective_type self._metric_name = metric_name @@ -47,17 +47,13 @@ def metric_name(self) -> str: def to_flyte_idl(self) -> _pb2_hpo_job.HyperparameterTuningObjective: return _pb2_hpo_job.HyperparameterTuningObjective( - objective_type=self.objective_type, - metric_name=self._metric_name, + objective_type=self.objective_type, metric_name=self._metric_name, ) @classmethod def from_flyte_idl(cls, pb2_object: _pb2_hpo_job.HyperparameterTuningObjective): - return cls( - objective_type=pb2_object.objective_type, - metric_name=pb2_object.metric_name, - ) + return cls(objective_type=pb2_object.objective_type, metric_name=pb2_object.metric_name,) class HyperparameterTuningStrategy: @@ -75,12 +71,13 @@ class HyperparameterTuningJobConfig(_common.FlyteIdlEntity): The specification of the hyperparameter tuning process https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-ex-tuning-job.html#automatic-model-tuning-ex-low-tuning-config """ + def __init__( - self, - hyperparameter_ranges: _parameter_ranges_models.ParameterRanges, - tuning_strategy: int, - tuning_objective: HyperparameterTuningObjective, - training_job_early_stopping_type: int, + self, + hyperparameter_ranges: _parameter_ranges_models.ParameterRanges, + tuning_strategy: int, + tuning_objective: HyperparameterTuningObjective, + training_job_early_stopping_type: int, ): self._hyperparameter_ranges = hyperparameter_ranges 
self._tuning_strategy = tuning_strategy @@ -138,7 +135,8 @@ def from_flyte_idl(cls, pb2_object: _pb2_hpo_job.HyperparameterTuningJobConfig): return cls( hyperparameter_ranges=( - _parameter_ranges_models.ParameterRanges.from_flyte_idl(pb2_object.hyperparameter_ranges)), + _parameter_ranges_models.ParameterRanges.from_flyte_idl(pb2_object.hyperparameter_ranges) + ), tuning_strategy=pb2_object.tuning_strategy, tuning_objective=HyperparameterTuningObjective.from_flyte_idl(pb2_object.tuning_objective), training_job_early_stopping_type=pb2_object.training_job_early_stopping_type, @@ -146,12 +144,11 @@ def from_flyte_idl(cls, pb2_object: _pb2_hpo_job.HyperparameterTuningJobConfig): class HyperparameterTuningJob(_common.FlyteIdlEntity): - def __init__( - self, - max_number_of_training_jobs: int, - max_parallel_training_jobs: int, - training_job: _training_job.TrainingJob, + self, + max_number_of_training_jobs: int, + max_parallel_training_jobs: int, + training_job: _training_job.TrainingJob, ): self._max_number_of_training_jobs = max_number_of_training_jobs self._max_parallel_training_jobs = max_parallel_training_jobs diff --git a/flytekit/models/sagemaker/parameter_ranges.py b/flytekit/models/sagemaker/parameter_ranges.py index 4c7dec2c9b..62b02e5c90 100644 --- a/flytekit/models/sagemaker/parameter_ranges.py +++ b/flytekit/models/sagemaker/parameter_ranges.py @@ -16,10 +16,7 @@ class HyperparameterScalingType(object): class ContinuousParameterRange(_common.FlyteIdlEntity): def __init__( - self, - max_value: float, - min_value: float, - scaling_type: int, + self, max_value: float, min_value: float, scaling_type: int, ): """ @@ -61,9 +58,7 @@ def to_flyte_idl(self) -> _idl_parameter_ranges.ContinuousParameterRange: """ return _idl_parameter_ranges.ContinuousParameterRange( - max_value=self._max_value, - min_value=self._min_value, - scaling_type=self.scaling_type, + max_value=self._max_value, min_value=self._min_value, scaling_type=self.scaling_type, ) @classmethod @@ -74,18 +69,13 @@ def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.ContinuousParameterRan :rtype: ContinuousParameterRange """ return cls( - max_value=pb2_object.max_value, - min_value=pb2_object.min_value, - scaling_type=pb2_object.scaling_type, + max_value=pb2_object.max_value, min_value=pb2_object.min_value, scaling_type=pb2_object.scaling_type, ) class IntegerParameterRange(_common.FlyteIdlEntity): def __init__( - self, - max_value: int, - min_value: int, - scaling_type: int, + self, max_value: int, min_value: int, scaling_type: int, ): """ :param int max_value: @@ -124,9 +114,7 @@ def to_flyte_idl(self) -> _idl_parameter_ranges.IntegerParameterRange: :rtype: _idl_parameter_ranges.IntegerParameterRange """ return _idl_parameter_ranges.IntegerParameterRange( - max_value=self._max_value, - min_value=self._min_value, - scaling_type=self.scaling_type, + max_value=self._max_value, min_value=self._min_value, scaling_type=self.scaling_type, ) @classmethod @@ -137,16 +125,13 @@ def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.IntegerParameterRange) :rtype: IntegerParameterRange """ return cls( - max_value=pb2_object.max_value, - min_value=pb2_object.min_value, - scaling_type=pb2_object.scaling_type, + max_value=pb2_object.max_value, min_value=pb2_object.min_value, scaling_type=pb2_object.scaling_type, ) class CategoricalParameterRange(_common.FlyteIdlEntity): def __init__( - self, - values: List[str], + self, values: List[str], ): """ @@ -165,21 +150,16 @@ def to_flyte_idl(self) -> 
_idl_parameter_ranges.CategoricalParameterRange: """ :rtype: _idl_parameter_ranges.CategoricalParameterRange """ - return _idl_parameter_ranges.CategoricalParameterRange( - values=self._values - ) + return _idl_parameter_ranges.CategoricalParameterRange(values=self._values) @classmethod def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.CategoricalParameterRange): - return cls( - values=pb2_object.values - ) + return cls(values=pb2_object.values) class ParameterRanges(_common.FlyteIdlEntity): def __init__( - self, - parameter_range_map: Dict[str, _common.FlyteIdlEntity], + self, parameter_range_map: Dict[str, _common.FlyteIdlEntity], ): self._parameter_range_map = parameter_range_map @@ -193,9 +173,7 @@ def to_flyte_idl(self) -> _idl_parameter_ranges.ParameterRanges: else: converted[k] = _idl_parameter_ranges.ParameterRangeOneOf(categorical_parameter_range=v.to_flyte_idl()) - return _idl_parameter_ranges.ParameterRanges( - parameter_range_map=converted, - ) + return _idl_parameter_ranges.ParameterRanges(parameter_range_map=converted,) @classmethod def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.ParameterRanges): @@ -208,6 +186,4 @@ def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.ParameterRanges): else: converted[k] = CategoricalParameterRange.from_flyte_idl(v) - return cls( - parameter_range_map=converted, - ) + return cls(parameter_range_map=converted,) diff --git a/flytekit/models/sagemaker/training_job.py b/flytekit/models/sagemaker/training_job.py index db6e8cade4..6247306fcb 100644 --- a/flytekit/models/sagemaker/training_job.py +++ b/flytekit/models/sagemaker/training_job.py @@ -13,11 +13,9 @@ class TrainingJobResourceConfig(_common.FlyteIdlEntity): number of instances to launch, and the size of the ML storage volume the user wants to provision Refer to SageMaker official doc for more details: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html """ + def __init__( - self, - instance_count: int, - instance_type: str, - volume_size_in_gb: int, + self, instance_count: int, instance_type: str, volume_size_in_gb: int, ): self._instance_count = instance_count self._instance_type = instance_type @@ -74,9 +72,7 @@ def from_flyte_idl(cls, pb2_object: _training_job_pb2.TrainingJobResourceConfig) class MetricDefinition(_common.FlyteIdlEntity): def __init__( - self, - name: str, - regex: str, + self, name: str, regex: str, ): self._name = name self._regex = regex @@ -103,10 +99,7 @@ def to_flyte_idl(self) -> _training_job_pb2.MetricDefinition: :rtype: _training_job_pb2.MetricDefinition """ - return _training_job_pb2.MetricDefinition( - name=self.name, - regex=self.regex, - ) + return _training_job_pb2.MetricDefinition(name=self.name, regex=self.regex,) @classmethod def from_flyte_idl(cls, pb2_object: _training_job_pb2.MetricDefinition): @@ -115,10 +108,7 @@ def from_flyte_idl(cls, pb2_object: _training_job_pb2.MetricDefinition): :param pb2_object: _training_job_pb2.MetricDefinition :rtype: MetricDefinition """ - return cls( - name=pb2_object.name, - regex=pb2_object.regex, - ) + return cls(name=pb2_object.name, regex=pb2_object.regex,) class InputMode(object): @@ -127,6 +117,7 @@ class InputMode(object): See https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html """ + PIPE = _training_job_pb2.InputMode.PIPE FILE = _training_job_pb2.InputMode.FILE @@ -138,6 +129,7 @@ class AlgorithmName(object): While we currently only support a subset of 
the algorithms, more will be added to the list. See: https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html """ + CUSTOM = _training_job_pb2.AlgorithmName.CUSTOM XGBOOST = _training_job_pb2.AlgorithmName.XGBOOST @@ -148,6 +140,7 @@ class InputContentType(object): See https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html """ + TEXT_CSV = _training_job_pb2.InputContentType.TEXT_CSV @@ -161,13 +154,14 @@ class AlgorithmSpecification(_common.FlyteIdlEntity): For pass-through use cases: refer to this AWS official document for more details https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html """ + def __init__( - self, - algorithm_name: int, - algorithm_version: str, - input_mode: int, - metric_definitions: List[MetricDefinition] = None, - input_content_type: int = InputContentType.TEXT_CSV, + self, + algorithm_name: int, + algorithm_version: str, + input_mode: int, + metric_definitions: List[MetricDefinition] = None, + input_content_type: int = InputContentType.TEXT_CSV, ): self._input_mode = input_mode self._input_content_type = input_content_type @@ -249,9 +243,7 @@ def from_flyte_idl(cls, pb2_object: _training_job_pb2.AlgorithmSpecification): class TrainingJob(_common.FlyteIdlEntity): def __init__( - self, - algorithm_specification: AlgorithmSpecification, - training_job_resource_config: TrainingJobResourceConfig, + self, algorithm_specification: AlgorithmSpecification, training_job_resource_config: TrainingJobResourceConfig, ): self._algorithm_specification = algorithm_specification self._training_job_resource_config = training_job_resource_config diff --git a/flytekit/models/schedule.py b/flytekit/models/schedule.py index da41ea1daa..91cb811824 100644 --- a/flytekit/models/schedule.py +++ b/flytekit/models/schedule.py @@ -27,7 +27,6 @@ def enum_to_string(cls, int_value): return "{}".format(int_value) class FixedRate(_common.FlyteIdlEntity): - def __init__(self, value, unit): """ :param int value: @@ -117,5 +116,5 @@ def from_flyte_idl(cls, pb2_object): return cls( pb2_object.kickoff_time_input_arg, cron_expression=pb2_object.cron_expression if pb2_object.HasField("cron_expression") else None, - rate=Schedule.FixedRate.from_flyte_idl(pb2_object.rate) if pb2_object.HasField("rate") else None + rate=Schedule.FixedRate.from_flyte_idl(pb2_object.rate) if pb2_object.HasField("rate") else None, ) diff --git a/flytekit/models/task.py b/flytekit/models/task.py index bf849f9ecf..eb594c424c 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -4,15 +4,21 @@ import six as _six from flyteidl.admin import task_pb2 as _admin_task -from flyteidl.core import tasks_pb2 as _core_task, literals_pb2 as _literals_pb2, compiler_pb2 as _compiler -from flyteidl.plugins import spark_pb2 as _spark_task +from flyteidl.core import compiler_pb2 as _compiler +from flyteidl.core import literals_pb2 as _literals_pb2 +from flyteidl.core import tasks_pb2 as _core_task from flyteidl.plugins import pytorch_pb2 as _pytorch_task +from flyteidl.plugins import spark_pb2 as _spark_task +from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct + +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.models import common as _common +from flytekit.models import interface as _interface +from flytekit.models import literals as _literals +from flytekit.models.core import identifier as _identifier from 
flytekit.plugins import flyteidl as _lazy_flyteidl -from google.protobuf import json_format as _json_format, struct_pb2 as _struct from flytekit.sdk.spark_types import SparkType as _spark_type -from flytekit.models import common as _common, literals as _literals, interface as _interface -from flytekit.models.core import identifier as _identifier -from flytekit.common.exceptions import user as _user_exceptions class Resources(_common.FlyteIdlEntity): @@ -24,7 +30,6 @@ class ResourceName(object): STORAGE = _core_task.Resources.STORAGE class ResourceEntry(_common.FlyteIdlEntity): - def __init__(self, name, value): """ :param int name: enum value from ResourceName @@ -94,8 +99,7 @@ def to_flyte_idl(self): :rtype: flyteidl.core.tasks_pb2.Resources """ return _core_task.Resources( - requests=[r.to_flyte_idl() for r in self.requests], - limits=[r.to_flyte_idl() for r in self.limits] + requests=[r.to_flyte_idl() for r in self.requests], limits=[r.to_flyte_idl() for r in self.limits], ) @classmethod @@ -106,7 +110,7 @@ def from_flyte_idl(cls, pb2_object): """ return cls( requests=[Resources.ResourceEntry.from_flyte_idl(r) for r in pb2_object.requests], - limits=[Resources.ResourceEntry.from_flyte_idl(l) for l in pb2_object.limits] + limits=[Resources.ResourceEntry.from_flyte_idl(l) for l in pb2_object.limits], ) @@ -154,11 +158,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.tasks_pb2.RuntimeMetadata """ - return _core_task.RuntimeMetadata( - type=self.type, - version=self.version, - flavor=self.flavor - ) + return _core_task.RuntimeMetadata(type=self.type, version=self.version, flavor=self.flavor) @classmethod def from_flyte_idl(cls, pb2_object): @@ -166,17 +166,13 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.tasks_pb2.RuntimeMetadata pb2_object: :rtype: RuntimeMetadata """ - return cls( - type=pb2_object.type, - version=pb2_object.version, - flavor=pb2_object.flavor - ) + return cls(type=pb2_object.type, version=pb2_object.version, flavor=pb2_object.flavor) class TaskMetadata(_common.FlyteIdlEntity): - - def __init__(self, discoverable, runtime, timeout, retries, interruptible, discovery_version, - deprecated_error_message): + def __init__( + self, discoverable, runtime, timeout, retries, interruptible, discovery_version, deprecated_error_message, + ): """ Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts, and retries. @@ -271,7 +267,7 @@ def to_flyte_idl(self): retries=self.retries.to_flyte_idl(), interruptible=self.interruptible, discovery_version=self.discovery_version, - deprecated_error_message=self.deprecated_error_message + deprecated_error_message=self.deprecated_error_message, ) tm.timeout.FromTimedelta(self.timeout) return tm @@ -289,12 +285,11 @@ def from_flyte_idl(cls, pb2_object): interruptible=pb2_object.interruptible if pb2_object.HasField("interruptible") else None, retries=_literals.RetryStrategy.from_flyte_idl(pb2_object.retries), discovery_version=pb2_object.discovery_version, - deprecated_error_message=pb2_object.deprecated_error_message + deprecated_error_message=pb2_object.deprecated_error_message, ) class TaskTemplate(_common.FlyteIdlEntity): - def __init__(self, id, type, metadata, interface, custom, container=None): """ A task template represents the full set of information necessary to perform a unit of work in the Flyte system. 
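For readers skimming the `Resources` hunks earlier in this file: the model is just two lists of `ResourceEntry(name, value)` pairs keyed by the `ResourceName` enum, which `to_flyte_idl()` packs into `flyteidl.core.tasks_pb2.Resources`. A small sketch assembled only from the signatures visible in this diff:

```python
from flytekit.models.task import Resources

res = Resources(
    requests=[Resources.ResourceEntry(Resources.ResourceName.CPU, "1")],
    limits=[Resources.ResourceEntry(Resources.ResourceName.MEMORY, "2Gi")],
)

# Mirrors the reformatted to_flyte_idl() above: each entry is converted
# and packed into the flyteidl.core.tasks_pb2.Resources message.
pb = res.to_flyte_idl()
assert pb.requests[0].value == "1" and pb.limits[0].value == "2Gi"
```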
@@ -377,7 +372,7 @@ def to_flyte_idl(self): metadata=self.metadata.to_flyte_idl(), interface=self.interface.to_flyte_idl(), custom=_json_format.Parse(_json.dumps(self.custom), _struct.Struct()) if self.custom else None, - container=self.container.to_flyte_idl() if self.container else None + container=self.container.to_flyte_idl() if self.container else None, ) @classmethod @@ -392,12 +387,11 @@ def from_flyte_idl(cls, pb2_object): metadata=TaskMetadata.from_flyte_idl(pb2_object.metadata), interface=_interface.TypedInterface.from_flyte_idl(pb2_object.interface), custom=_json_format.MessageToDict(pb2_object.custom) if pb2_object else None, - container=Container.from_flyte_idl(pb2_object.container) if pb2_object.HasField("container") else None + container=Container.from_flyte_idl(pb2_object.container) if pb2_object.HasField("container") else None, ) class TaskSpec(_common.FlyteIdlEntity): - def __init__(self, template): """ :param TaskTemplate template: @@ -415,9 +409,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.tasks_pb2.TaskSpec """ - return _admin_task.TaskSpec( - template=self.template.to_flyte_idl() - ) + return _admin_task.TaskSpec(template=self.template.to_flyte_idl()) @classmethod def from_flyte_idl(cls, pb2_object): @@ -429,7 +421,6 @@ def from_flyte_idl(cls, pb2_object): class Task(_common.FlyteIdlEntity): - def __init__(self, id, closure): """ :param flytekit.models.core.identifier.Identifier id: The (project, domain, name) identifier for this task. @@ -458,10 +449,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.task_pb2.Task """ - return _admin_task.Task( - closure=self.closure.to_flyte_idl(), - id=self.id.to_flyte_idl(), - ) + return _admin_task.Task(closure=self.closure.to_flyte_idl(), id=self.id.to_flyte_idl(),) @classmethod def from_flyte_idl(cls, pb2_object): @@ -471,12 +459,11 @@ def from_flyte_idl(cls, pb2_object): """ return cls( closure=TaskClosure.from_flyte_idl(pb2_object.closure), - id=_identifier.Identifier.from_flyte_idl(pb2_object.id) + id=_identifier.Identifier.from_flyte_idl(pb2_object.id), ) class TaskClosure(_common.FlyteIdlEntity): - def __init__(self, compiled_task): """ :param CompiledTask compiled_task: @@ -494,9 +481,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.task_pb2.TaskClosure """ - return _admin_task.TaskClosure( - compiled_task=self.compiled_task.to_flyte_idl() - ) + return _admin_task.TaskClosure(compiled_task=self.compiled_task.to_flyte_idl()) @classmethod def from_flyte_idl(cls, pb2_object): @@ -504,13 +489,10 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.task_pb2.TaskClosure pb2_object: :rtype: TaskClosure """ - return cls( - compiled_task=CompiledTask.from_flyte_idl(pb2_object.compiled_task) - ) + return cls(compiled_task=CompiledTask.from_flyte_idl(pb2_object.compiled_task)) class CompiledTask(_common.FlyteIdlEntity): - def __init__(self, template): """ :param TaskTemplate template: @@ -528,9 +510,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.compiler_pb2.CompiledTask """ - return _compiler.CompiledTask( - template=self.template.to_flyte_idl() - ) + return _compiler.CompiledTask(template=self.template.to_flyte_idl()) @classmethod def from_flyte_idl(cls, pb2_object): @@ -538,14 +518,13 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.compiler_pb2.CompiledTask pb2_object: :rtype: CompiledTask """ - return cls( - template=TaskTemplate.from_flyte_idl(pb2_object.template) - ) + return cls(template=TaskTemplate.from_flyte_idl(pb2_object.template)) class 
SparkJob(_common.FlyteIdlEntity): - - def __init__(self, spark_type, application_file, main_class, spark_conf, hadoop_conf, executor_path): + def __init__( + self, spark_type, application_file, main_class, spark_conf, hadoop_conf, executor_path, + ): """ This defines a SparkJob target. It will execute the appropriate SparkJob. @@ -662,6 +641,7 @@ class IOStrategy(_common.FlyteIdlEntity): """ Provides methods to manage data in and out of the Raw container using Download Modes. This can only be used if DataLoadingConfig is enabled. """ + DOWNLOAD_MODE_EAGER = _core_task.IOStrategy.DOWNLOAD_EAGER DOWNLOAD_MODE_STREAM = _core_task.IOStrategy.DOWNLOAD_STREAM DOWNLOAD_MODE_NO_DOWNLOAD = _core_task.IOStrategy.DO_NOT_DOWNLOAD @@ -670,26 +650,22 @@ class IOStrategy(_common.FlyteIdlEntity): UPLOAD_MODE_ON_EXIT = _core_task.IOStrategy.UPLOAD_ON_EXIT UPLOAD_MODE_NO_UPLOAD = _core_task.IOStrategy.DO_NOT_UPLOAD - def __init__(self, - download_mode: _core_task.IOStrategy.DownloadMode=DOWNLOAD_MODE_EAGER, - upload_mode: _core_task.IOStrategy.UploadMode=UPLOAD_MODE_ON_EXIT): + def __init__( + self, + download_mode: _core_task.IOStrategy.DownloadMode = DOWNLOAD_MODE_EAGER, + upload_mode: _core_task.IOStrategy.UploadMode = UPLOAD_MODE_ON_EXIT, + ): self._download_mode = download_mode self._upload_mode = upload_mode def to_flyte_idl(self) -> _core_task.IOStrategy: - return _core_task.IOStrategy( - download_mode=self._download_mode, - upload_mode=self._upload_mode - ) + return _core_task.IOStrategy(download_mode=self._download_mode, upload_mode=self._upload_mode) @classmethod def from_flyte_idl(cls, pb2_object: _core_task.IOStrategy): if pb2_object is None: return None - return cls( - download_mode=pb2_object.download_mode, - upload_mode=pb2_object.upload_mode, - ) + return cls(download_mode=pb2_object.download_mode, upload_mode=pb2_object.upload_mode,) class DataLoadingConfig(_common.FlyteIdlEntity): @@ -698,11 +674,18 @@ class DataLoadingConfig(_common.FlyteIdlEntity): LITERALMAP_FORMAT_YAML = _core_task.DataLoadingConfig.YAML _LITERALMAP_FORMATS = frozenset([LITERALMAP_FORMAT_JSON, LITERALMAP_FORMAT_PROTO, LITERALMAP_FORMAT_YAML]) - def __init__(self, input_path: str, output_path: str, enabled: bool = True, - format: _core_task.DataLoadingConfig.LiteralMapFormat = LITERALMAP_FORMAT_PROTO, io_strategy: IOStrategy=None): + def __init__( + self, + input_path: str, + output_path: str, + enabled: bool = True, + format: _core_task.DataLoadingConfig.LiteralMapFormat = LITERALMAP_FORMAT_PROTO, + io_strategy: IOStrategy = None, + ): if format not in self._LITERALMAP_FORMATS: raise ValueError( - "Metadata format {} not supported. Should be one of {}".format(format, self._LITERALMAP_FORMATS)) + "Metadata format {} not supported. Should be one of {}".format(format, self._LITERALMAP_FORMATS) + ) self._input_path = input_path self._output_path = output_path self._enabled = enabled @@ -733,7 +716,6 @@ def from_flyte_idl(cls, pb2: _core_task.DataLoadingConfig): class Container(_common.FlyteIdlEntity): - def __init__(self, image, command, args, resources, env, config, data_loading_config=None): """ This defines a container target. 
It will execute the appropriate command line on the appropriate image with @@ -840,12 +822,12 @@ def from_flyte_idl(cls, pb2_object): env={kv.key: kv.value for kv in pb2_object.env}, config={kv.key: kv.value for kv in pb2_object.config}, data_loading_config=DataLoadingConfig.from_flyte_idl(pb2_object.data_config) - if pb2_object.HasField("data_config") else None, + if pb2_object.HasField("data_config") + else None, ) class SidecarJob(_common.FlyteIdlEntity): - def __init__(self, pod_spec, primary_container_name): """ A sidecar job represents the full kubernetes pod spec and related metadata required for executing a sidecar @@ -876,8 +858,7 @@ def to_flyte_idl(self): :rtype: flyteidl.core.tasks_pb2.SidecarJob """ return _lazy_flyteidl.plugins.sidecar_pb2.SidecarJob( - pod_spec=self.pod_spec, - primary_container_name=self.primary_container_name + pod_spec=self.pod_spec, primary_container_name=self.primary_container_name ) @classmethod @@ -886,14 +867,10 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.task_pb2.Task pb2_object: :rtype: Container """ - return cls( - pod_spec=pb2_object.pod_spec, - primary_container_name=pb2_object.primary_container_name, - ) + return cls(pod_spec=pb2_object.pod_spec, primary_container_name=pb2_object.primary_container_name,) class PyTorchJob(_common.FlyteIdlEntity): - def __init__(self, workers_count): self._workers_count = workers_count @@ -902,12 +879,8 @@ def workers_count(self): return self._workers_count def to_flyte_idl(self): - return _pytorch_task.DistributedPyTorchTrainingTask( - workers=self.workers_count, - ) + return _pytorch_task.DistributedPyTorchTrainingTask(workers=self.workers_count,) @classmethod def from_flyte_idl(cls, pb2_object): - return cls( - workers_count=pb2_object.workers, - ) + return cls(workers_count=pb2_object.workers,) diff --git a/flytekit/models/types.py b/flytekit/models/types.py index ab4945b956..1414c8ed1c 100644 --- a/flytekit/models/types.py +++ b/flytekit/models/types.py @@ -1,11 +1,13 @@ from __future__ import absolute_import +import json as _json + from flyteidl.core import types_pb2 as _types_pb2 +from google.protobuf import json_format as _json_format +from google.protobuf import struct_pb2 as _struct from flytekit.models import common as _common from flytekit.models.core import types as _core_types -from google.protobuf import json_format as _json_format, struct_pb2 as _struct -import json as _json class SimpleType(object): @@ -22,9 +24,7 @@ class SimpleType(object): class SchemaType(_common.FlyteIdlEntity): - class SchemaColumn(_common.FlyteIdlEntity): - class SchemaColumnType(object): INTEGER = _types_pb2.SchemaType.SchemaColumn.INTEGER FLOAT = _types_pb2.SchemaType.SchemaColumn.FLOAT @@ -61,10 +61,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.types_pb2.SchemaType.SchemaColumn """ - return _types_pb2.SchemaType.SchemaColumn( - name=self.name, - type=self.type - ) + return _types_pb2.SchemaType.SchemaColumn(name=self.name, type=self.type) @classmethod def from_flyte_idl(cls, proto): @@ -92,9 +89,7 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.types_pb2.SchemaType """ - return _types_pb2.SchemaType( - columns=[c.to_flyte_idl() for c in self.columns] - ) + return _types_pb2.SchemaType(columns=[c.to_flyte_idl() for c in self.columns]) @classmethod def from_flyte_idl(cls, proto): @@ -106,8 +101,9 @@ def from_flyte_idl(cls, proto): class LiteralType(_common.FlyteIdlEntity): - - def __init__(self, simple=None, schema=None, collection_type=None, map_value_type=None, blob=None, metadata=None): 
+ def __init__( + self, simple=None, schema=None, collection_type=None, map_value_type=None, blob=None, metadata=None, + ): """ Only one of the kwargs may be set. :param int simple: Enum type from SimpleType @@ -185,7 +181,7 @@ def to_flyte_idl(self): collection_type=self.collection_type.to_flyte_idl() if self.collection_type is not None else None, map_value_type=self.map_value_type.to_flyte_idl() if self.map_value_type is not None else None, blob=self.blob.to_flyte_idl() if self.blob is not None else None, - metadata=metadata + metadata=metadata, ) return t @@ -207,7 +203,7 @@ def from_flyte_idl(cls, proto): collection_type=collection_type, map_value_type=map_value_type, blob=_core_types.BlobType.from_flyte_idl(proto.blob) if proto.HasField("blob") else None, - metadata=_json_format.MessageToDict(proto.metadata) or None + metadata=_json_format.MessageToDict(proto.metadata) or None, ) @@ -255,7 +251,4 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.types.OutputReference pb2_object: :rtype: OutputReference """ - return cls( - node_id=pb2_object.node_id, - var=pb2_object.var - ) + return cls(node_id=pb2_object.node_id, var=pb2_object.var) diff --git a/flytekit/models/workflow_closure.py b/flytekit/models/workflow_closure.py index 4504dd91ec..ffbfa7b3dc 100644 --- a/flytekit/models/workflow_closure.py +++ b/flytekit/models/workflow_closure.py @@ -1,13 +1,13 @@ from __future__ import absolute_import from flyteidl.core import workflow_closure_pb2 as _workflow_closure_pb2 + from flytekit.models import common as _common -from flytekit.models.core import workflow as _core_workflow_models from flytekit.models import task as _task_models +from flytekit.models.core import workflow as _core_workflow_models class WorkflowClosure(_common.FlyteIdlEntity): - def __init__(self, workflow, tasks=None): """ :param flytekit.models.core.workflow.WorkflowTemplate workflow: Workflow template @@ -35,8 +35,7 @@ def to_flyte_idl(self): :rtype: flyteidl.core.workflow_closure_pb2.WorkflowClosure """ return _workflow_closure_pb2.WorkflowClosure( - workflow=self.workflow.to_flyte_idl(), - tasks=[t.to_flyte_idl() for t in self.tasks] + workflow=self.workflow.to_flyte_idl(), tasks=[t.to_flyte_idl() for t in self.tasks], ) @classmethod diff --git a/flytekit/plugins/__init__.py b/flytekit/plugins/__init__.py index b626856a99..0a87cc37a7 100644 --- a/flytekit/plugins/__init__.py +++ b/flytekit/plugins/__init__.py @@ -1,62 +1,34 @@ from __future__ import absolute_import -from flytekit.tools import lazy_loader as _lazy_loader +from flytekit.tools import lazy_loader as _lazy_loader -pyspark = _lazy_loader.lazy_load_module("pyspark") # type: types.ModuleType +pyspark = _lazy_loader.lazy_load_module("pyspark") # type: _lazy_loader._LazyLoadModule -k8s = _lazy_loader.lazy_load_module("k8s") # type: types.ModuleType +k8s = _lazy_loader.lazy_load_module("k8s") # type: _lazy_loader._LazyLoadModule type(k8s).add_sub_module("io.api.core.v1.generated_pb2") type(k8s).add_sub_module("io.apimachinery.pkg.api.resource.generated_pb2") -flyteidl = _lazy_loader.lazy_load_module("flyteidl") # type: types.ModuleType +flyteidl = _lazy_loader.lazy_load_module("flyteidl") # type: _lazy_loader._LazyLoadModule type(flyteidl).add_sub_module("plugins.sidecar_pb2") -numpy = _lazy_loader.lazy_load_module("numpy") # type: types.ModuleType -pandas = _lazy_loader.lazy_load_module("pandas") # type: types.ModuleType +numpy = _lazy_loader.lazy_load_module("numpy") # type: _lazy_loader._LazyLoadModule +pandas = 
_lazy_loader.lazy_load_module("pandas") # type: _lazy_loader._LazyLoadModule -hmsclient = _lazy_loader.lazy_load_module("hmsclient") # type: types.ModuleType +hmsclient = _lazy_loader.lazy_load_module("hmsclient") # type: _lazy_loader._LazyLoadModule type(hmsclient).add_sub_module("genthrift.hive_metastore.ttypes") -torch = _lazy_loader.lazy_load_module("torch") # type: types.ModuleType +torch = _lazy_loader.lazy_load_module("torch") # type: _lazy_loader._LazyLoadModule -_lazy_loader.LazyLoadPlugin( - "spark", - ["pyspark>=2.4.0,<3.0.0"], - [pyspark] -) +_lazy_loader.LazyLoadPlugin("spark", ["pyspark>=2.4.0,<3.0.0"], [pyspark]) -_lazy_loader.LazyLoadPlugin( - "spark3", - ["pyspark>=3.0.0"], - [pyspark] -) +_lazy_loader.LazyLoadPlugin("spark3", ["pyspark>=3.0.0"], [pyspark]) -_lazy_loader.LazyLoadPlugin( - "sidecar", - ["k8s-proto>=0.0.3,<1.0.0"], - [k8s, flyteidl] -) +_lazy_loader.LazyLoadPlugin("sidecar", ["k8s-proto>=0.0.3,<1.0.0"], [k8s, flyteidl]) _lazy_loader.LazyLoadPlugin( - "schema", - [ - "numpy>=1.14.0,<2.0.0", - "pandas>=0.22.0,<2.0.0", - "pyarrow>=0.11.0,<1.0.0", - ], - [numpy, pandas] + "schema", ["numpy>=1.14.0,<2.0.0", "pandas>=0.22.0,<2.0.0", "pyarrow>=0.11.0,<1.0.0"], [numpy, pandas], ) -_lazy_loader.LazyLoadPlugin( - "hive_sensor", - [ - "hmsclient>=0.0.1,<1.0.0", - ], - [hmsclient] -) +_lazy_loader.LazyLoadPlugin("hive_sensor", ["hmsclient>=0.0.1,<1.0.0"], [hmsclient]) -_lazy_loader.LazyLoadPlugin( - "pytorch", - ["torch>=1.0.0,<2.0.0"], - [torch] -) +_lazy_loader.LazyLoadPlugin("pytorch", ["torch>=1.0.0,<2.0.0"], [torch]) diff --git a/flytekit/sdk/exceptions.py b/flytekit/sdk/exceptions.py index 1e9ea62b88..6753d4182e 100644 --- a/flytekit/sdk/exceptions.py +++ b/flytekit/sdk/exceptions.py @@ -1,4 +1,5 @@ from __future__ import absolute_import + from flytekit.common.exceptions import user as _user @@ -8,4 +9,5 @@ class RecoverableException(_user.FlyteRecoverableException): Any exception raised from user code other than RecoverableException will NOT be considered retryable and the task will fail without additional retries. 
""" + pass diff --git a/flytekit/sdk/spark_types.py b/flytekit/sdk/spark_types.py index ef5804ecde..9477895fac 100644 --- a/flytekit/sdk/spark_types.py +++ b/flytekit/sdk/spark_types.py @@ -5,4 +5,4 @@ class SparkType(enum.Enum): PYTHON = 1 SCALA = 2 JAVA = 3 - R = 4 \ No newline at end of file + R = 4 diff --git a/flytekit/sdk/tasks.py b/flytekit/sdk/tasks.py index ce37a69b18..00d06135d1 100644 --- a/flytekit/sdk/tasks.py +++ b/flytekit/sdk/tasks.py @@ -6,9 +6,13 @@ from flytekit.common import constants as _common_constants from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.tasks import sdk_runnable as _sdk_runnable_tasks, sdk_dynamic as _sdk_dynamic, \ - spark_task as _sdk_spark_tasks, generic_spark_task as _sdk_generic_spark_task, hive_task as _sdk_hive_tasks, \ - sidecar_task as _sdk_sidecar_tasks, pytorch_task as _sdk_pytorch_tasks +from flytekit.common.tasks import generic_spark_task as _sdk_generic_spark_task +from flytekit.common.tasks import hive_task as _sdk_hive_tasks +from flytekit.common.tasks import pytorch_task as _sdk_pytorch_tasks +from flytekit.common.tasks import sdk_dynamic as _sdk_dynamic +from flytekit.common.tasks import sdk_runnable as _sdk_runnable_tasks +from flytekit.common.tasks import sidecar_task as _sdk_sidecar_tasks +from flytekit.common.tasks import spark_task as _sdk_spark_tasks from flytekit.common.tasks import task as _task from flytekit.common.types import helpers as _type_helpers from flytekit.contrib.notebook import tasks as _nb_tasks @@ -41,20 +45,18 @@ def my_task(wf_params, in1, in2, out1, out2): def apply_inputs_wrapper(task): if not isinstance(task, _task.SdkTask): - additional_msg = \ - "Inputs can only be applied to a task. Did you forget the task decorator on method '{}.{}'?".format( - task.__module__, - task.__name__ if hasattr(task, "__name__") else "" - ) + additional_msg = "Inputs can only be applied to a task. Did you forget the task decorator on method '{}.{}'?".format( + task.__module__, task.__name__ if hasattr(task, "__name__") else "", + ) raise _user_exceptions.FlyteTypeException( expected_type=_sdk_runnable_tasks.SdkRunnableTask, received_type=type(task), received_value=task, - additional_msg=additional_msg) + additional_msg=additional_msg, + ) for k, v in _six.iteritems(kwargs): kwargs[k] = _interface_model.Variable( - _type_helpers.python_std_to_sdk_type(v).to_flyte_literal_type(), - '' + _type_helpers.python_std_to_sdk_type(v).to_flyte_literal_type(), "" ) # TODO: Support descriptions task.add_inputs(kwargs) @@ -88,22 +90,23 @@ def my_task(wf_params, out1, out2): name and type. :rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask """ + def apply_outputs_wrapper(task): - if not isinstance(task, _sdk_runnable_tasks.SdkRunnableTask) and not isinstance(task, _nb_tasks.SdkNotebookTask): - additional_msg = \ - "Outputs can only be applied to a task. Did you forget the task decorator on method '{}.{}'?".format( - task.__module__, - task.__name__ if hasattr(task, "__name__") else "" - ) + if not isinstance(task, _sdk_runnable_tasks.SdkRunnableTask) and not isinstance( + task, _nb_tasks.SdkNotebookTask + ): + additional_msg = "Outputs can only be applied to a task. 
Did you forget the task decorator on method '{}.{}'?".format( + task.__module__, task.__name__ if hasattr(task, "__name__") else "", + ) raise _user_exceptions.FlyteTypeException( expected_type=_sdk_runnable_tasks.SdkRunnableTask, received_type=type(task), received_value=task, - additional_msg=additional_msg) + additional_msg=additional_msg, + ) for k, v in _six.iteritems(kwargs): kwargs[k] = _interface_model.Variable( - _type_helpers.python_std_to_sdk_type(v).to_flyte_literal_type(), - '' + _type_helpers.python_std_to_sdk_type(v).to_flyte_literal_type(), "" ) # TODO: Support descriptions task.add_outputs(kwargs) @@ -116,23 +119,23 @@ def apply_outputs_wrapper(task): def python_task( - _task_function=None, - cache_version='', - retries=0, - interruptible=None, - deprecated='', - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - cache=False, - timeout=None, - environment=None, - cls=None, + _task_function=None, + cache_version="", + retries=0, + interruptible=None, + deprecated="", + storage_request=None, + cpu_request=None, + gpu_request=None, + memory_request=None, + storage_limit=None, + cpu_limit=None, + gpu_limit=None, + memory_limit=None, + cache=False, + timeout=None, + environment=None, + cls=None, ): """ Decorator to create a Python Task definition. This task will run as a single unit of work on the platform. @@ -225,6 +228,7 @@ def my_task(wf_params, int_list, sum_of_list): :rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask """ + def wrapper(fn): return (cls or _sdk_runnable_tasks.SdkRunnableTask)( task_function=fn, @@ -244,7 +248,8 @@ def wrapper(fn): discoverable=cache, timeout=timeout or _datetime.timedelta(seconds=0), environment=environment, - custom={}) + custom={}, + ) if _task_function: return wrapper(_task_function) @@ -253,25 +258,25 @@ def wrapper(fn): def dynamic_task( - _task_function=None, - cache_version='', - retries=0, - interruptible=None, - deprecated='', - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - cache=False, - timeout=None, - allowed_failure_ratio=None, - max_concurrency=None, - environment=None, - cls=None + _task_function=None, + cache_version="", + retries=0, + interruptible=None, + deprecated="", + storage_request=None, + cpu_request=None, + gpu_request=None, + memory_request=None, + storage_limit=None, + cpu_limit=None, + gpu_limit=None, + memory_limit=None, + cache=False, + timeout=None, + allowed_failure_ratio=None, + max_concurrency=None, + environment=None, + cls=None, ): """ Decorator to create a custom dynamic task definition. Dynamic tasks should be used to split up work into @@ -406,17 +411,17 @@ def wrapper(fn): def spark_task( - _task_function=None, - cache_version='', - retries=0, - interruptible=None, - deprecated='', - cache=False, - timeout=None, - spark_conf=None, - hadoop_conf=None, - environment=None, - cls=None + _task_function=None, + cache_version="", + retries=0, + interruptible=None, + deprecated="", + cache=False, + timeout=None, + spark_conf=None, + hadoop_conf=None, + environment=None, + cls=None, ): """ Decorator to create a spark task. 
This task will connect to a Spark cluster, configure the environment, @@ -478,7 +483,7 @@ def wrapper(fn): discovery_version=cache_version, retries=retries, interruptible=interruptible, - spark_type= _spark_type.PYTHON, + spark_type=_spark_type.PYTHON, deprecated=deprecated, discoverable=cache, timeout=timeout or _datetime.timedelta(seconds=0), @@ -494,19 +499,19 @@ def wrapper(fn): def generic_spark_task( - spark_type, - main_class, - main_application_file, - cache_version='', - retries=0, - interruptible=None, - inputs=None, - deprecated='', - cache=False, - timeout=None, - spark_conf=None, - hadoop_conf=None, - environment=None, + spark_type, + main_class, + main_application_file, + cache_version="", + retries=0, + interruptible=None, + inputs=None, + deprecated="", + cache=False, + timeout=None, + spark_conf=None, + hadoop_conf=None, + environment=None, ): """ Create a generic spark task. This task will connect to a Spark cluster, configure the environment, @@ -515,21 +520,21 @@ def generic_spark_task( """ return _sdk_generic_spark_task.SdkGenericSparkTask( - task_type=_common_constants.SdkTaskType.SPARK_TASK, - discovery_version=cache_version, - retries=retries, - interruptible=interruptible, - deprecated=deprecated, - discoverable=cache, - timeout=timeout or _datetime.timedelta(seconds=0), - spark_type = spark_type, - task_inputs= inputs, - main_class = main_class or "", - main_application_file = main_application_file or "", - spark_conf=spark_conf or {}, - hadoop_conf=hadoop_conf or {}, - environment=environment or {}, - ) + task_type=_common_constants.SdkTaskType.SPARK_TASK, + discovery_version=cache_version, + retries=retries, + interruptible=interruptible, + deprecated=deprecated, + discoverable=cache, + timeout=timeout or _datetime.timedelta(seconds=0), + spark_type=spark_type, + task_inputs=inputs, + main_class=main_class or "", + main_application_file=main_application_file or "", + spark_conf=spark_conf or {}, + hadoop_conf=hadoop_conf or {}, + environment=environment or {}, + ) def qubole_spark_task(*args, **kwargs): @@ -540,24 +545,24 @@ def qubole_spark_task(*args, **kwargs): def hive_task( - _task_function=None, - cache_version='', - retries=0, - interruptible=None, - deprecated='', - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - cache=False, - timeout=None, - environment=None, - cls=None - ): + _task_function=None, + cache_version="", + retries=0, + interruptible=None, + deprecated="", + storage_request=None, + cpu_request=None, + gpu_request=None, + memory_request=None, + storage_limit=None, + cpu_limit=None, + gpu_limit=None, + memory_limit=None, + cache=False, + timeout=None, + environment=None, + cls=None, +): """ Decorator to create a hive task. This task should output a list of hive queries which are run on a hive cluster. 
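All of the task decorators reindented in this file share the optional-call pattern visible in these hunks: used bare, `_task_function` receives the decorated function directly; used with keyword arguments, the call returns `wrapper`, which is then applied to the function. A sketch of both spellings against the legacy SDK, assuming the `python_task` defaults shown above:

```python
from flytekit.sdk.tasks import inputs, outputs, python_task
from flytekit.sdk.types import Types

# Bare form: python_task is handed the function itself (_task_function is set).
@python_task
def noop(wf_params):
    pass

# Parameterized form: python_task(...) returns `wrapper`, which then wraps
# the function; inputs/outputs decorate the resulting task object.
@inputs(num=Types.Integer)
@outputs(out=Types.Integer)
@python_task(cache=True, cache_version="1")
def double(wf_params, num, out):
    out.set(num * 2)
```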
@@ -648,6 +653,7 @@ def test_hive(wf_params, a): :rtype: flytekit.common.tasks.sdk_runnable.SdkHiveTask """ + def wrapper(fn): return (cls or _sdk_hive_tasks.SdkHiveTask)( @@ -667,7 +673,7 @@ def wrapper(fn): memory_limit=memory_limit, discoverable=cache, timeout=timeout or _datetime.timedelta(seconds=0), - cluster_label='', + cluster_label="", tags=[], environment=environment or {}, ) @@ -679,26 +685,26 @@ def wrapper(fn): def qubole_hive_task( - _task_function=None, - cache_version='', - retries=0, - interruptible=None, - deprecated='', - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - cache=False, - timeout=None, - cluster_label=None, - tags=None, - environment=None, - cls=None - ): + _task_function=None, + cache_version="", + retries=0, + interruptible=None, + deprecated="", + storage_request=None, + cpu_request=None, + gpu_request=None, + memory_request=None, + storage_limit=None, + cpu_limit=None, + gpu_limit=None, + memory_limit=None, + cache=False, + timeout=None, + cluster_label=None, + tags=None, + environment=None, + cls=None, +): """ Decorator to create a qubole hive task. This hive task runs on a qubole cluster, and therefore allows users to pass cluster labels and qubole query tags. Similar to hive task, this task should output a list of hive queries @@ -792,6 +798,7 @@ def test_hive(wf_params, a): :rtype: flytekit.common.tasks.sdk_runnable.SdkHiveTask """ + def wrapper(fn): return (cls or _sdk_hive_tasks.SdkHiveTask)( @@ -811,7 +818,7 @@ def wrapper(fn): memory_limit=memory_limit, discoverable=cache, timeout=timeout or _datetime.timedelta(seconds=0), - cluster_label=cluster_label or '', + cluster_label=cluster_label or "", tags=tags or [], environment=environment or {}, ) @@ -825,25 +832,25 @@ def wrapper(fn): def sidecar_task( - _task_function=None, - cache_version='', - retries=0, - interruptible=None, - deprecated='', - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - cache=False, - timeout=None, - environment=None, - pod_spec=None, - primary_container_name=None, - cls=None, + _task_function=None, + cache_version="", + retries=0, + interruptible=None, + deprecated="", + storage_request=None, + cpu_request=None, + gpu_request=None, + memory_request=None, + storage_limit=None, + cpu_limit=None, + gpu_limit=None, + memory_limit=None, + cache=False, + timeout=None, + environment=None, + pod_spec=None, + primary_container_name=None, + cls=None, ): """ Decorator to create a Sidecar Task definition.
This task will execute the primary task alongside the specified @@ -975,6 +982,7 @@ def a_sidecar_task(wfparams): :rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask """ + def wrapper(fn): return (cls or _sdk_sidecar_tasks.SdkSidecarTask)( @@ -1006,27 +1014,27 @@ def wrapper(fn): def dynamic_sidecar_task( - _task_function=None, - cache_version="", - retries=0, - interruptible=None, - deprecated="", - storage_request=None, - cpu_request=None, - gpu_request=None, - memory_request=None, - storage_limit=None, - cpu_limit=None, - gpu_limit=None, - memory_limit=None, - cache=False, - timeout=None, - allowed_failure_ratio=None, - max_concurrency=None, - environment=None, - pod_spec=None, - primary_container_name=None, - cls=None, + _task_function=None, + cache_version="", + retries=0, + interruptible=None, + deprecated="", + storage_request=None, + cpu_request=None, + gpu_request=None, + memory_request=None, + storage_limit=None, + cpu_limit=None, + gpu_limit=None, + memory_limit=None, + cache=False, + timeout=None, + allowed_failure_ratio=None, + max_concurrency=None, + environment=None, + pod_spec=None, + primary_container_name=None, + cls=None, ): """ Decorator to create a custom dynamic sidecar task definition. Dynamic @@ -1167,6 +1175,7 @@ def my_task(wf_params, out): logic into the base Flyte programming model. :rtype: flytekit.common.tasks.sidecar_Task.SdkDynamicSidecarTask """ + def wrapper(fn): return (cls or _sdk_sidecar_tasks.SdkDynamicSidecarTask)( task_function=fn, @@ -1199,24 +1208,24 @@ def wrapper(fn): def pytorch_task( - _task_function=None, - cache_version='', - retries=0, - interruptible=False, - deprecated='', - cache=False, - timeout=None, - workers_count=1, - per_replica_storage_request="", - per_replica_cpu_request="", - per_replica_gpu_request="", - per_replica_memory_request="", - per_replica_storage_limit="", - per_replica_cpu_limit="", - per_replica_gpu_limit="", - per_replica_memory_limit="", - environment=None, - cls=None + _task_function=None, + cache_version="", + retries=0, + interruptible=False, + deprecated="", + cache=False, + timeout=None, + workers_count=1, + per_replica_storage_request="", + per_replica_cpu_request="", + per_replica_gpu_request="", + per_replica_memory_request="", + per_replica_storage_limit="", + per_replica_cpu_limit="", + per_replica_gpu_limit="", + per_replica_memory_limit="", + environment=None, + cls=None, ): """ Decorator to create a Pytorch Task definition. 
This task will submit PyTorchJob (see https://github.com/kubeflow/pytorch-operator) @@ -1324,6 +1333,7 @@ def my_pytorch_job(wf_params, int_list, result): :rtype: flytekit.common.tasks.sdk_runnable.SdkRunnableTask """ + def wrapper(fn): return (cls or _sdk_pytorch_tasks.SdkPyTorchTask)( task_function=fn, @@ -1343,7 +1353,7 @@ def wrapper(fn): per_replica_cpu_limit=per_replica_cpu_limit, per_replica_gpu_limit=per_replica_gpu_limit, per_replica_memory_limit=per_replica_memory_limit, - environment=environment or {} + environment=environment or {}, ) if _task_function: diff --git a/flytekit/sdk/test_utils.py b/flytekit/sdk/test_utils.py index f410e0959e..c31a6c3933 100644 --- a/flytekit/sdk/test_utils.py +++ b/flytekit/sdk/test_utils.py @@ -1,7 +1,9 @@ from __future__ import absolute_import + +from wrapt import decorator as _decorator + from flytekit.common import utils as _utils from flytekit.interfaces.data import data_proxy as _data_proxy -from wrapt import decorator as _decorator class LocalTestFileSystem(object): diff --git a/flytekit/sdk/types.py b/flytekit/sdk/types.py index dcb2cdeaec..bcc158bd90 100644 --- a/flytekit/sdk/types.py +++ b/flytekit/sdk/types.py @@ -1,6 +1,11 @@ from __future__ import absolute_import -from flytekit.common.types import primitives as _primitives, blobs as _blobs, schema as _schema, helpers as _helpers, \ - proto as _proto, containers as _containers + +from flytekit.common.types import blobs as _blobs +from flytekit.common.types import containers as _containers +from flytekit.common.types import helpers as _helpers +from flytekit.common.types import primitives as _primitives +from flytekit.common.types import proto as _proto +from flytekit.common.types import schema as _schema class Types(object): diff --git a/flytekit/sdk/workflow.py b/flytekit/sdk/workflow.py index a6abe2bb26..5fc1539ee7 100644 --- a/flytekit/sdk/workflow.py +++ b/flytekit/sdk/workflow.py @@ -2,7 +2,8 @@ import six as _six -from flytekit.common import workflow as _common_workflow, promise as _promise +from flytekit.common import promise as _promise +from flytekit.common import workflow as _common_workflow from flytekit.common.types import helpers as _type_helpers @@ -20,7 +21,7 @@ def __init__(self, sdk_type, help=None, **kwargs): :param bool required: If set, default must be None :param T default: If this is not a required input, the value will default to this value. Specify as a kwarg. """ - super(Input, self).__init__('', _type_helpers.python_std_to_sdk_type(sdk_type), help=help, **kwargs) + super(Input, self).__init__("", _type_helpers.python_std_to_sdk_type(sdk_type), help=help, **kwargs) class Output(_common_workflow.Output): @@ -37,10 +38,7 @@ def __init__(self, value, sdk_type=None, help=None): this value be provided as the SDK might not always be able to infer the correct type. 
""" super(Output, self).__init__( - '', - value, - sdk_type=_type_helpers.python_std_to_sdk_type(sdk_type) if sdk_type else None, - help=help + "", value, sdk_type=_type_helpers.python_std_to_sdk_type(sdk_type) if sdk_type else None, help=help, ) @@ -117,5 +115,6 @@ def workflow(nodes, inputs=None, outputs=None, cls=None, on_failure=None): inputs=[v.rename_and_return_reference(k) for k, v in sorted(_six.iteritems(inputs or {}))], outputs=[v.rename_and_return_reference(k) for k, v in sorted(_six.iteritems(outputs or {}))], nodes=[v.assign_id_and_return(k) for k, v in sorted(_six.iteritems(nodes))], - metadata=_common_workflow._workflow_models.WorkflowMetadata(on_failure=on_failure)) + metadata=_common_workflow._workflow_models.WorkflowMetadata(on_failure=on_failure), + ) return wf diff --git a/flytekit/tools/lazy_loader.py b/flytekit/tools/lazy_loader.py index cbf9367f79..7c546f7241 100644 --- a/flytekit/tools/lazy_loader.py +++ b/flytekit/tools/lazy_loader.py @@ -1,4 +1,5 @@ from __future__ import absolute_import + import importlib as _importlib import sys as _sys import types as _types @@ -27,18 +28,19 @@ def get_extras_require(cls): all_plugins = [] for k in d: # Default to Spark 2.4.x . - if k !="spark3": + if k != "spark3": all_plugins.extend(d[k]) - d['all'] = all_plugins + d["all"] = all_plugins return d def lazy_load_module(module): """ - :param Text module: + :param Text module: :rtype: _types.ModuleType """ + class LazyLoadModule(_LazyLoadModule): _module = module _lazy_submodules = dict() @@ -49,11 +51,13 @@ class LazyLoadModule(_LazyLoadModule): class _LazyLoadModule(_types.ModuleType): - _ERROR_MSG_FMT = "Attempting to use a plugin functionality that requires module " \ - "`{module}`, but it couldn't be loaded. Please pip install at least one of {plugins} or " \ - "`flytekit[all]` to get these dependencies.\n" \ - "\n" \ - "Original message: {msg}" + _ERROR_MSG_FMT = ( + "Attempting to use a plugin functionality that requires module " + "`{module}`, but it couldn't be loaded. 
Please pip install at least one of {plugins} or " + "`flytekit[all]` to get these dependencies.\n" + "\n" + "Original message: {msg}" + ) @classmethod def _load(cls): @@ -62,13 +66,7 @@ def _load(cls): try: module = _importlib.import_module(cls._module) except ImportError as e: - raise ImportError( - cls._ERROR_MSG_FMT.format( - module=cls._module, - plugins=cls._plugins, - msg=e - ) - ) + raise ImportError(cls._ERROR_MSG_FMT.format(module=cls._module, plugins=cls._plugins, msg=e)) return module def __getattribute__(self, item): diff --git a/flytekit/tools/module_loader.py b/flytekit/tools/module_loader.py index f928febc3a..59e8d0ce1a 100644 --- a/flytekit/tools/module_loader.py +++ b/flytekit/tools/module_loader.py @@ -14,7 +14,7 @@ def iterate_modules(pkgs): for package_name in pkgs: package = importlib.import_module(package_name) yield package - for _, name, _ in pkgutil.walk_packages(package.__path__, prefix='{}.'.format(package_name)): + for _, name, _ in pkgutil.walk_packages(package.__path__, prefix="{}.".format(package_name)): yield importlib.import_module(name) @@ -31,40 +31,36 @@ def load_workflow_modules(pkgs): def _topo_sort_helper( - obj, - entity_to_module_key, - visited, - recursion_set, - recursion_stack, - include_entities, - ignore_entities, - detect_unreferenced_entities): + obj, + entity_to_module_key, + visited, + recursion_set, + recursion_stack, + include_entities, + ignore_entities, + detect_unreferenced_entities, +): visited.add(obj) recursion_stack.append(obj) if obj in recursion_set: raise _user_exceptions.FlyteAssertion( "A cyclical dependency was detected during topological sort of entities. " - "Cycle path was:\n\n\t{}".format( - "\n\t".join( - p for p in recursion_stack[recursion_set[obj]:] - ) - ) + "Cycle path was:\n\n\t{}".format("\n\t".join(p for p in recursion_stack[recursion_set[obj] :])) ) recursion_set[obj] = len(recursion_stack) - 1 for upstream in obj.upstream_entities: if upstream not in visited: - for m1, k1, o1 in \ - _topo_sort_helper( - upstream, - entity_to_module_key, - visited, - recursion_set, - recursion_stack, - include_entities, - ignore_entities, - detect_unreferenced_entities - ): + for m1, k1, o1 in _topo_sort_helper( + upstream, + entity_to_module_key, + visited, + recursion_set, + recursion_stack, + include_entities, + ignore_entities, + detect_unreferenced_entities, + ): yield m1, k1, o1 recursion_stack.pop() @@ -83,10 +79,7 @@ def _topo_sort_helper( def iterate_registerable_entities_in_order( - pkgs, - ignore_entities=None, - include_entities=None, - detect_unreferenced_entities=True + pkgs, ignore_entities=None, include_entities=None, detect_unreferenced_entities=True ): """ This function will iterate all discovered entities in the given package list. 
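A sketch of how this iterator is typically consumed at registration time; the package name is illustrative, and each yielded triple is assumed to be (module, entity key, entity), matching the `m, k, o2` values the topological-sort helper above yields:

```python
from flytekit.tools.module_loader import iterate_registerable_entities_in_order

# Walk every registerable entity discovered under the given packages,
# yielded in topologically sorted order (upstream entities first).
for module, key, entity in iterate_registerable_entities_in_order(["myapp.workflows"]):
    print("{}.{} -> {}".format(module.__name__, key, type(entity).__name__))
```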
It will then attempt to @@ -127,15 +120,14 @@ def iterate_registerable_entities_in_order( if o not in visited: recursion_set = dict() recursion_stack = [] - for m, k, o2 in \ - _topo_sort_helper( - o, - entity_to_module_key, - visited, - recursion_set, - recursion_stack, - include_entities, - ignore_entities, - detect_unreferenced_entities=detect_unreferenced_entities - ): + for m, k, o2 in _topo_sort_helper( + o, + entity_to_module_key, + visited, + recursion_set, + recursion_stack, + include_entities, + ignore_entities, + detect_unreferenced_entities=detect_unreferenced_entities, + ): yield m, k, o2 diff --git a/flytekit/tools/subprocess.py b/flytekit/tools/subprocess.py index 8753e883ae..a437b61a1a 100644 --- a/flytekit/tools/subprocess.py +++ b/flytekit/tools/subprocess.py @@ -1,5 +1,4 @@ -from __future__ import absolute_import -from __future__ import print_function +from __future__ import absolute_import, print_function import logging import shlex as _schlex @@ -14,12 +13,7 @@ def check_call(cmd_args, **kwargs): # Jupyter notebooks hijack I/O and thus we cannot dump directly to stdout. with _tempfile.TemporaryFile() as std_out: with _tempfile.TemporaryFile() as std_err: - ret_code = _subprocess.Popen( - cmd_args, - stdout=std_out, - stderr=std_err, - **kwargs - ).wait() + ret_code = _subprocess.Popen(cmd_args, stdout=std_out, stderr=std_err, **kwargs).wait() # Dump sub-process' std out into current std out std_out.seek(0) diff --git a/flytekit/type_engines/common.py b/flytekit/type_engines/common.py index 8d08958c7a..5ab9dfca0e 100644 --- a/flytekit/type_engines/common.py +++ b/flytekit/type_engines/common.py @@ -1,10 +1,11 @@ from __future__ import absolute_import + import abc as _abc + import six as _six class TypeEngine(_six.with_metaclass(_abc.ABCMeta, object)): - @_abc.abstractmethod def python_std_to_sdk_type(self, t): """ diff --git a/flytekit/type_engines/default/flyte.py b/flytekit/type_engines/default/flyte.py index 3e992f4ca9..b7c329d8a3 100644 --- a/flytekit/type_engines/default/flyte.py +++ b/flytekit/type_engines/default/flyte.py @@ -1,13 +1,18 @@ from __future__ import absolute_import -from flytekit.common.exceptions import system as _system_exceptions, user as _user_exceptions -from flytekit.common.types import primitives as _primitive_types, base_sdk_types as _base_sdk_types, containers as \ - _container_types, schema as _schema, blobs as _blobs, proto as _proto -from flytekit.models import types as _literal_type_models -from flytekit.models.core import types as _core_types import importlib as _importer +from flytekit.common.exceptions import system as _system_exceptions +from flytekit.common.exceptions import user as _user_exceptions +from flytekit.common.types import base_sdk_types as _base_sdk_types +from flytekit.common.types import blobs as _blobs +from flytekit.common.types import containers as _container_types from flytekit.common.types import helpers as _helpers +from flytekit.common.types import primitives as _primitive_types +from flytekit.common.types import proto as _proto +from flytekit.common.types import schema as _schema +from flytekit.models import types as _literal_type_models +from flytekit.models.core import types as _core_types def _proto_sdk_type_from_tag(tag): @@ -15,26 +20,21 @@ def _proto_sdk_type_from_tag(tag): :param Text tag: :rtype: _proto.Protobuf """ - if '.' not in tag: + if "." not in tag: raise _user_exceptions.FlyteValueException( - tag, - "Protobuf tag must include at least one '.' to delineate package and object name." 
+ tag, "Protobuf tag must include at least one '.' to delineate package and object name.", ) - module, name = tag.rsplit('.', 1) + module, name = tag.rsplit(".", 1) try: pb_module = _importer.import_module(module) except ImportError: raise _user_exceptions.FlyteAssertion( - "Could not resolve the protobuf definition @ {}. Is the protobuf library installed?".format( - module - ) + "Could not resolve the protobuf definition @ {}. Is the protobuf library installed?".format(module) ) if not hasattr(pb_module, name): - raise _user_exceptions.FlyteAssertion( - "Could not find the protobuf named: {} @ {}.".format(name, module) - ) + raise _user_exceptions.FlyteAssertion("Could not find the protobuf named: {} @ {}.".format(name, module)) return _proto.create_protobuf(getattr(pb_module, name)) @@ -62,7 +62,8 @@ def python_std_to_sdk_type(self, t): if len(t) != 1: raise _user_exceptions.FlyteAssertion( "When specifying a list type, there must be exactly one element in " - "the list describing the contained type.") + "the list describing the contained type." + ) return _container_types.List(_helpers.python_std_to_sdk_type(t[0])) elif isinstance(t, dict): raise _user_exceptions.FlyteAssertion("Map types are not yet implemented.") @@ -73,8 +74,8 @@ def python_std_to_sdk_type(self, t): type(t), _base_sdk_types.FlyteSdkType, additional_msg="Should be of form similar to: Types.Integer, [Types.Integer], {Types.String: " - "Types.Integer}", - received_value=t + "Types.Integer}", + received_value=t, ) def get_sdk_type_from_literal_type(self, literal_type): @@ -91,8 +92,10 @@ def get_sdk_type_from_literal_type(self, literal_type): elif literal_type.blob is not None: return self._get_blob_impl_from_type(literal_type.blob) elif literal_type.simple is not None: - if literal_type.simple == _literal_type_models.SimpleType.BINARY and _proto.Protobuf.PB_FIELD_KEY in \ - literal_type.metadata: + if ( + literal_type.simple == _literal_type_models.SimpleType.BINARY + and _proto.Protobuf.PB_FIELD_KEY in literal_type.metadata + ): return _proto_sdk_type_from_tag(literal_type.metadata[_proto.Protobuf.PB_FIELD_KEY]) sdk_type = self._SIMPLE_TYPE_LOOKUP_TABLE.get(literal_type.simple) if sdk_type is None: @@ -129,7 +132,7 @@ def infer_sdk_type_from_literal(self, literal): # noqa sdk_type = _primitive_types.Generic elif literal.scalar.binary is not None: if literal.scalar.binary.tag.startswith(_proto.Protobuf.TAG_PREFIX): - sdk_type = _proto_sdk_type_from_tag(literal.scalar.binary.tag[len(_proto.Protobuf.TAG_PREFIX):]) + sdk_type = _proto_sdk_type_from_tag(literal.scalar.binary.tag[len(_proto.Protobuf.TAG_PREFIX) :]) else: raise NotImplementedError("TODO: Binary is only supported for protobuf types currently") elif literal.scalar.primitive.boolean is not None: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..55ec8d784c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +line-length = 120 diff --git a/requirements.in b/requirements.in new file mode 100644 index 0000000000..5b206568a6 --- /dev/null +++ b/requirements.in @@ -0,0 +1,2 @@ +.[all] +-e file:.#egg=flytekit diff --git a/requirements.txt b/requirements.txt index 8fc7f72648..d15cee8c90 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,75 @@ -pytest==4.6.6 -mock==3.0.5 -six==1.12.0 \ No newline at end of file +# +# This file is autogenerated by pip-compile +# To update, run: +# +# make requirements.txt +# +-e file:.#egg=flytekit # via -r requirements.in +ansiwrap==0.8.4 # via papermill +appdirs==1.4.4 # via 
black +async-generator==1.10 # via nbclient +attrs==19.3.0 # via black, jsonschema +black==19.10b0 # via papermill +boto3==1.14.41 # via flytekit +botocore==1.17.41 # via boto3, s3transfer +certifi==2020.6.20 # via requests +chardet==3.0.4 # via requests +click==7.1.2 # via black, flytekit, hmsclient, papermill +croniter==0.3.34 # via flytekit +decorator==4.4.2 # via traitlets +deprecated==1.2.10 # via flytekit +docutils==0.15.2 # via botocore +entrypoints==0.3 # via papermill +flyteidl==0.18.2 # via flytekit +future==0.18.2 # via torch +grpcio==1.31.0 # via flytekit +hmsclient==0.1.1 # via flytekit +idna==2.10 # via requests +importlib-metadata==1.7.0 # via jsonschema, keyring +ipython-genutils==0.2.0 # via nbformat, traitlets +jmespath==0.10.0 # via boto3, botocore +jsonschema==3.2.0 # via nbformat +jupyter-client==6.1.6 # via nbclient, papermill +jupyter-core==4.6.3 # via jupyter-client, nbformat +k8s-proto==0.0.3 # via flytekit +keyring==21.3.0 # via flytekit +natsort==7.0.1 # via croniter +nbclient==0.4.1 # via papermill +nbformat==5.0.7 # via nbclient, papermill +nest-asyncio==1.4.0 # via nbclient +numpy==1.19.1 # via flytekit, pandas, pyarrow, torch +pandas==1.1.0 # via flytekit +papermill==2.1.2 # via flytekit +pathspec==0.8.0 # via black +protobuf==3.12.4 # via flyteidl, flytekit, k8s-proto +py4j==0.10.7 # via pyspark +pyarrow==0.17.1 # via flytekit +pyrsistent==0.16.0 # via jsonschema +pyspark==2.4.6 # via flytekit +python-dateutil==2.8.1 # via botocore, croniter, flytekit, jupyter-client, pandas +pytimeparse==1.1.8 # via flytekit +pytz==2018.4 # via flytekit, pandas +pyyaml==5.3.1 # via papermill +pyzmq==19.0.2 # via jupyter-client +regex==2020.7.14 # via black +requests==2.24.0 # via flytekit, papermill, responses +responses==0.10.16 # via flytekit +s3transfer==0.3.3 # via boto3 +six==1.15.0 # via flytekit, grpcio, jsonschema, protobuf, python-dateutil, responses, tenacity, thrift, traitlets +sortedcontainers==2.2.2 # via flytekit +statsd==3.3.0 # via flytekit +tenacity==6.2.0 # via papermill +textwrap3==0.9.2 # via ansiwrap +thrift==0.13.0 # via hmsclient +toml==0.10.1 # via black +torch==1.6.0 # via flytekit +tornado==6.0.4 # via jupyter-client +tqdm==4.48.2 # via papermill +traitlets==4.3.3 # via jupyter-client, jupyter-core, nbclient, nbformat +typed-ast==1.4.1 # via black +urllib3==1.25.10 # via botocore, flytekit, requests, responses +wrapt==1.12.1 # via deprecated, flytekit +zipp==3.1.0 # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/sample-notebooks/image.py b/sample-notebooks/image.py index d21323349c..72d093909e 100644 --- a/sample-notebooks/image.py +++ b/sample-notebooks/image.py @@ -1,6 +1,8 @@ -import cv2 import sys +import cv2 + + def filter_edges(input_image_path: str, output_image_path: str): print("Reading {}".format(input_image_path)) img = cv2.imread(input_image_path, 0) @@ -10,6 +12,7 @@ def filter_edges(input_image_path: str, output_image_path: str): cv2.imwrite(output_image_path, edges) return output_image_path + if __name__ == "__main__": inp = sys.argv[1] out = sys.argv[2] diff --git a/setup.cfg b/setup.cfg index b7e58d5f7b..c19835b199 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,14 @@ +[isort] +multi_line_output = 3 +include_trailing_comma = True +force_grid_wrap = 0 +use_parentheses = True +ensure_newline_before_comments = True +line_length = 120 + [flake8] -format=pylint max-line-length = 120 +extend-ignore = E203, E266, E501, W503, E741 exclude = 
.svn,CVS,.bzr,.hg,.git,__pycache__,venv/*,src/*,tests/unit/common/protos/* max-complexity=16 @@ -9,9 +17,5 @@ norecursedirs = common workflows spark log_cli = true log_cli_level = 20 -[pep8] -max-line-length = 120 - [coverage:run] branch = True - diff --git a/setup.py b/setup.py index cead7aeb85..cb19deaa8d 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,7 @@ from __future__ import absolute_import -from setuptools import setup, find_packages # noqa +from setuptools import find_packages, setup # noqa + import flytekit # noqa from flytekit.tools.lazy_loader import LazyLoadPlugin # noqa @@ -9,24 +10,24 @@ extras_require[':python_version<"3"'] = [ "configparser>=3.0.0,<4.0.0", "futures>=3.2.0,<4.0.0", - "pathlib2>=2.3.2,<3.0.0" + "pathlib2>=2.3.2,<3.0.0", ] setup( - name='flytekit', + name="flytekit", version=flytekit.__version__, - maintainer='Lyft', - maintainer_email='flyte-eng@lyft.com', + maintainer="Lyft", + maintainer_email="flyte-eng@lyft.com", packages=find_packages(exclude=["tests*"]), - url='https://github.com/lyft/flytekit', - description='Flyte SDK for Python', - long_description=open('README.md').read(), + url="https://github.com/lyft/flytekit", + description="Flyte SDK for Python", + long_description=open("README.md").read(), long_description_content_type="text/markdown", entry_points={ - 'console_scripts': [ - 'pyflyte-execute=flytekit.bin.entrypoint:execute_task_cmd', - 'pyflyte=flytekit.clis.sdk_in_container.pyflyte:main', - 'flyte-cli=flytekit.clis.flyte_cli.main:_flyte_cli' + "console_scripts": [ + "pyflyte-execute=flytekit.bin.entrypoint:execute_task_cmd", + "pyflyte=flytekit.clis.sdk_in_container.pyflyte:main", + "flyte-cli=flytekit.clis.flyte_cli.main:_flyte_cli", ] }, install_requires=[ @@ -52,10 +53,10 @@ ], extras_require=extras_require, scripts=[ - 'scripts/flytekit_install_spark.sh', - 'scripts/flytekit_install_spark3.sh', - 'scripts/flytekit_build_image.sh', - 'scripts/flytekit_venv' + "scripts/flytekit_install_spark.sh", + "scripts/flytekit_install_spark3.sh", + "scripts/flytekit_build_image.sh", + "scripts/flytekit_venv", ], license="apache2", python_requires=">=2.7", diff --git a/tests/flytekit/common/parameterizers.py b/tests/flytekit/common/parameterizers.py index 3f656eb790..13d6704e96 100644 --- a/tests/flytekit/common/parameterizers.py +++ b/tests/flytekit/common/parameterizers.py @@ -2,11 +2,14 @@ from datetime import timedelta from itertools import product + from six.moves import range -from flytekit.models import interface, literals, types, task -from flytekit.models.core import identifier, types as _core_types -from flytekit.common.types.impl import blobs as _blob_impl, schema as _schema_impl +from flytekit.common.types.impl import blobs as _blob_impl +from flytekit.common.types.impl import schema as _schema_impl +from flytekit.models import interface, literals, task, types +from flytekit.models.core import identifier +from flytekit.models.core import types as _core_types LIST_OF_SCALAR_LITERAL_TYPES = [ types.LiteralType(simple=types.SimpleType.BINARY), @@ -19,39 +22,29 @@ types.LiteralType(simple=types.SimpleType.NONE), types.LiteralType(simple=types.SimpleType.STRING), types.LiteralType( - schema=types.SchemaType([ - types.SchemaType.SchemaColumn("a", types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - types.SchemaType.SchemaColumn("b", types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), - types.SchemaType.SchemaColumn("c", types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - types.SchemaType.SchemaColumn("d", 
types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - types.SchemaType.SchemaColumn("e", types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - types.SchemaType.SchemaColumn("f", types.SchemaType.SchemaColumn.SchemaColumnType.STRING), - ]) + schema=types.SchemaType( + [ + types.SchemaType.SchemaColumn("a", types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), + types.SchemaType.SchemaColumn("b", types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), + types.SchemaType.SchemaColumn("c", types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), + types.SchemaType.SchemaColumn("d", types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), + types.SchemaType.SchemaColumn("e", types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), + types.SchemaType.SchemaColumn("f", types.SchemaType.SchemaColumn.SchemaColumnType.STRING), + ] + ) ), types.LiteralType( - blob=_core_types.BlobType( - format="", - dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, - ) + blob=_core_types.BlobType(format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,) ), types.LiteralType( - blob=_core_types.BlobType( - format="csv", - dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, - ) + blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,) ), types.LiteralType( - blob=_core_types.BlobType( - format="", - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, - ) + blob=_core_types.BlobType(format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) ), types.LiteralType( - blob=_core_types.BlobType( - format="csv", - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, - ) - ) + blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + ), ] @@ -63,15 +56,13 @@ types.LiteralType(collection_type=literal_type) for literal_type in LIST_OF_COLLECTION_LITERAL_TYPES ] -LIST_OF_ALL_LITERAL_TYPES = \ - LIST_OF_SCALAR_LITERAL_TYPES + \ - LIST_OF_COLLECTION_LITERAL_TYPES + \ - LIST_OF_NESTED_COLLECTION_LITERAL_TYPES +LIST_OF_ALL_LITERAL_TYPES = ( + LIST_OF_SCALAR_LITERAL_TYPES + LIST_OF_COLLECTION_LITERAL_TYPES + LIST_OF_NESTED_COLLECTION_LITERAL_TYPES +) LIST_OF_INTERFACES = [ interface.TypedInterface( - {'a': interface.Variable(t, "description 1")}, - {'b': interface.Variable(t, "description 2")} + {"a": interface.Variable(t, "description 1")}, {"b": interface.Variable(t, "description 2")}, ) for t in LIST_OF_ALL_LITERAL_TYPES ] @@ -81,13 +72,11 @@ task.Resources.ResourceEntry(task.Resources.ResourceName.CPU, "1"), task.Resources.ResourceEntry(task.Resources.ResourceName.GPU, "1"), task.Resources.ResourceEntry(task.Resources.ResourceName.MEMORY, "1G"), - task.Resources.ResourceEntry(task.Resources.ResourceName.STORAGE, "1G") + task.Resources.ResourceEntry(task.Resources.ResourceName.STORAGE, "1G"), ] -LIST_OF_RESOURCE_ENTRY_LISTS = [ - LIST_OF_RESOURCE_ENTRIES -] +LIST_OF_RESOURCE_ENTRY_LISTS = [LIST_OF_RESOURCE_ENTRIES] LIST_OF_RESOURCES = [ @@ -98,29 +87,17 @@ LIST_OF_RUNTIME_METADATA = [ task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.OTHER, "1.0.0", "python"), - task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0b0", "golang") + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0b0", "golang"), ] -LIST_OF_RETRY_POLICIES = [ - literals.RetryStrategy(retries=i) for i in [0, 1, 3, 100] -] +LIST_OF_RETRY_POLICIES = [literals.RetryStrategy(retries=i) for i in [0, 1, 3, 100]] -LIST_OF_INTERRUPTIBLE 
= [ - None, - True, - False -] +LIST_OF_INTERRUPTIBLE = [None, True, False] LIST_OF_TASK_METADATA = [ task.TaskMetadata( - discoverable, - runtime_metadata, - timeout, - retry_strategy, - interruptible, - discovery_version, - deprecated + discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated, ) for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated in product( [True, False], @@ -129,7 +106,7 @@ LIST_OF_RETRY_POLICIES, LIST_OF_INTERRUPTIBLE, ["1.0"], - ["deprecated"] + ["deprecated"], ) ] @@ -140,31 +117,17 @@ "python", task_metadata, interfaces, - {'a': 1, 'b': [1, 2, 3], 'c': 'abc', 'd': {'x': 1, 'y': 2, 'z': 3}}, + {"a": 1, "b": [1, 2, 3], "c": "abc", "d": {"x": 1, "y": 2, "z": 3}}, container=task.Container( - "my_image", - ["this", "is", "a", "cmd"], - ["this", "is", "an", "arg"], - resources, - {'a': 'b'}, - {'d': 'e'} - ) - ) - for task_metadata, interfaces, resources in product( - LIST_OF_TASK_METADATA, - LIST_OF_INTERFACES, - LIST_OF_RESOURCES + "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, + ), ) + for task_metadata, interfaces, resources in product(LIST_OF_TASK_METADATA, LIST_OF_INTERFACES, LIST_OF_RESOURCES) ] LIST_OF_CONTAINERS = [ task.Container( - "my_image", - ["this", "is", "a", "cmd"], - ["this", "is", "an", "arg"], - resources, - {'a': 'b'}, - {'d': 'e'} + "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, ) for resources in LIST_OF_RESOURCES ] @@ -175,93 +138,77 @@ (literals.Scalar(primitive=literals.Primitive(integer=100)), 100), (literals.Scalar(primitive=literals.Primitive(float_value=500.0)), 500.0), (literals.Scalar(primitive=literals.Primitive(boolean=True)), True), - (literals.Scalar(primitive=literals.Primitive(string_value='hello')), 'hello'), - (literals.Scalar(primitive=literals.Primitive(duration=timedelta(seconds=5))), timedelta(seconds=5)), + (literals.Scalar(primitive=literals.Primitive(string_value="hello")), "hello"), + (literals.Scalar(primitive=literals.Primitive(duration=timedelta(seconds=5))), timedelta(seconds=5),), (literals.Scalar(none_type=literals.Void()), None), ( literals.Scalar( blob=literals.Blob( - literals.BlobMetadata( - _core_types.BlobType( - "csv", - _core_types.BlobType.BlobDimensionality.SINGLE - ) - ), - "s3://some/where" + literals.BlobMetadata(_core_types.BlobType("csv", _core_types.BlobType.BlobDimensionality.SINGLE)), + "s3://some/where", ) ), - _blob_impl.Blob("s3://some/where", format="csv") + _blob_impl.Blob("s3://some/where", format="csv"), ), ( literals.Scalar( blob=literals.Blob( - literals.BlobMetadata( - _core_types.BlobType( - "", - _core_types.BlobType.BlobDimensionality.SINGLE - ) - ), - "s3://some/where" + literals.BlobMetadata(_core_types.BlobType("", _core_types.BlobType.BlobDimensionality.SINGLE)), + "s3://some/where", ) ), - _blob_impl.Blob("s3://some/where") + _blob_impl.Blob("s3://some/where"), ), ( literals.Scalar( blob=literals.Blob( - literals.BlobMetadata( - _core_types.BlobType( - "csv", - _core_types.BlobType.BlobDimensionality.MULTIPART - ) - ), - "s3://some/where/" + literals.BlobMetadata(_core_types.BlobType("csv", _core_types.BlobType.BlobDimensionality.MULTIPART)), + "s3://some/where/", ) ), - _blob_impl.MultiPartBlob("s3://some/where/", format="csv") + _blob_impl.MultiPartBlob("s3://some/where/", format="csv"), ), ( literals.Scalar( blob=literals.Blob( - literals.BlobMetadata( - 
_core_types.BlobType( - "", - _core_types.BlobType.BlobDimensionality.MULTIPART - ) - ), - "s3://some/where/" + literals.BlobMetadata(_core_types.BlobType("", _core_types.BlobType.BlobDimensionality.MULTIPART)), + "s3://some/where/", ) ), - _blob_impl.MultiPartBlob("s3://some/where/") + _blob_impl.MultiPartBlob("s3://some/where/"), ), ( literals.Scalar( schema=literals.Schema( "s3://some/where/", - types.SchemaType([ - types.SchemaType.SchemaColumn("a", types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - types.SchemaType.SchemaColumn("b", types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), - types.SchemaType.SchemaColumn("c", types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - types.SchemaType.SchemaColumn("d", types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - types.SchemaType.SchemaColumn("e", types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - types.SchemaType.SchemaColumn("f", types.SchemaType.SchemaColumn.SchemaColumnType.STRING), - ]) + types.SchemaType( + [ + types.SchemaType.SchemaColumn("a", types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), + types.SchemaType.SchemaColumn("b", types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), + types.SchemaType.SchemaColumn("c", types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), + types.SchemaType.SchemaColumn("d", types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), + types.SchemaType.SchemaColumn("e", types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), + types.SchemaType.SchemaColumn("f", types.SchemaType.SchemaColumn.SchemaColumnType.STRING), + ] + ), ) ), _schema_impl.Schema( "s3://some/where/", _schema_impl.SchemaType.promote_from_model( - types.SchemaType([ - types.SchemaType.SchemaColumn("a", types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - types.SchemaType.SchemaColumn("b", types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), - types.SchemaType.SchemaColumn("c", types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - types.SchemaType.SchemaColumn("d", types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - types.SchemaType.SchemaColumn("e", types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - types.SchemaType.SchemaColumn("f", types.SchemaType.SchemaColumn.SchemaColumnType.STRING), - ]) - ) - ) - ) + types.SchemaType( + [ + types.SchemaType.SchemaColumn("a", types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), + types.SchemaType.SchemaColumn("b", types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), + types.SchemaType.SchemaColumn("c", types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), + types.SchemaType.SchemaColumn("d", types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), + types.SchemaType.SchemaColumn("e", types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), + types.SchemaType.SchemaColumn("f", types.SchemaType.SchemaColumn.SchemaColumnType.STRING), + ] + ) + ), + ), + ), ] LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE = [ @@ -272,5 +219,6 @@ (literals.LiteralCollection(literals=[l, l, l]), [v, v, v]) for l, v in LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE ] -LIST_OF_ALL_LITERALS_AND_VALUES = LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE + \ - LIST_OF_LITERAL_COLLECTIONS_AND_PYTHON_VALUE +LIST_OF_ALL_LITERALS_AND_VALUES = ( + LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE + LIST_OF_LITERAL_COLLECTIONS_AND_PYTHON_VALUE +) diff --git a/tests/flytekit/common/task_definitions.py b/tests/flytekit/common/task_definitions.py index 85826a28d7..0381b203d1 100644 --- a/tests/flytekit/common/task_definitions.py +++ 
b/tests/flytekit/common/task_definitions.py @@ -1,5 +1,6 @@ from __future__ import absolute_import -from flytekit.sdk.tasks import python_task, inputs, outputs + +from flytekit.sdk.tasks import inputs, outputs, python_task from flytekit.sdk.types import Types diff --git a/tests/flytekit/common/workflows/batch.py b/tests/flytekit/common/workflows/batch.py index 9a1f5b568b..ed460adc92 100644 --- a/tests/flytekit/common/workflows/batch.py +++ b/tests/flytekit/common/workflows/batch.py @@ -1,12 +1,10 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function from six import moves as _six_moves -from flytekit.sdk.tasks import python_task, inputs, outputs, dynamic_task +from flytekit.sdk.tasks import dynamic_task, inputs, outputs, python_task from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class, Output, Input +from flytekit.sdk.workflow import Input, Output, workflow_class @outputs(out_ints=[Types.Integer]) @@ -54,7 +52,7 @@ def no_inputs_sample_batch_task(wf_params, out_str, out_ints): @inputs(in1=Types.Integer) @outputs(out_str=[Types.String]) -@dynamic_task(cache=True, cache_version='1') +@dynamic_task(cache=True, cache_version="1") def sample_batch_task_beatles_cached(wf_params, in1, out_str): wf_params.stats.incr("task_run") res2 = [] @@ -67,7 +65,7 @@ def sample_batch_task_beatles_cached(wf_params, in1, out_str): @inputs(in1=Types.Integer) @outputs(out1=Types.String) -@python_task(cache=True, cache_version='1') +@python_task(cache=True, cache_version="1") def sample_beatles_lyrics_cached(wf_params, in1, out1): wf_params.stats.incr("task_run") lyrics = ["Ob-La-Di, Ob-La-Da", "When I'm 64", "Yesterday"] @@ -99,15 +97,26 @@ def sq_sub_task(wf_params, in1, out1): @inputs(ints_to_print=[[Types.Integer]], strings_to_print=[Types.String]) -@python_task(cache_version='1') +@python_task(cache_version="1") def print_every_time(wf_params, ints_to_print, strings_to_print): wf_params.stats.incr("task_run") print("Expected Int values: {}".format([[0, 0, 0], [1, 2, 3], [2, 4, 6], [0, 1, 4], [0, 1, 4]])) print("Actual Int values: {}".format(ints_to_print)) - print("Expected String values: {}".format( - [u"I'm the first result", u"hello 0", u"I'm after each sub-task result", u'hello 1', - u"I'm after each sub-task result", u'hello 2', u"I'm after each sub-task result", u"I'm the last result"])) + print( + "Expected String values: {}".format( + [ + u"I'm the first result", + u"hello 0", + u"I'm after each sub-task result", + u"hello 1", + u"I'm after each sub-task result", + u"hello 2", + u"I'm after each sub-task result", + u"I'm the last result", + ] + ) + ) print("Actual String values: {}".format(strings_to_print)) diff --git a/tests/flytekit/common/workflows/dynamic_workflows.py b/tests/flytekit/common/workflows/dynamic_workflows.py index 2e254f1e68..97f5e81343 100644 --- a/tests/flytekit/common/workflows/dynamic_workflows.py +++ b/tests/flytekit/common/workflows/dynamic_workflows.py @@ -1,8 +1,9 @@ from __future__ import absolute_import, division, print_function -from flytekit.sdk import tasks as _tasks, workflow as _workflow +from flytekit.sdk import tasks as _tasks +from flytekit.sdk import workflow as _workflow from flytekit.sdk.types import Types as _Types -from flytekit.sdk.workflow import workflow_class, Input, Output +from flytekit.sdk.workflow import Input, Output, workflow_class @_tasks.inputs(num=_Types.Integer) diff --git 
a/tests/flytekit/common/workflows/failing_workflows.py b/tests/flytekit/common/workflows/failing_workflows.py index f02e7234cd..50d8aa0c63 100644 --- a/tests/flytekit/common/workflows/failing_workflows.py +++ b/tests/flytekit/common/workflows/failing_workflows.py @@ -1,10 +1,8 @@ from __future__ import absolute_import, division, print_function -from flytekit.sdk import tasks as _tasks, workflow as _workflow -from flytekit.sdk.tasks import python_task -from flytekit.sdk.types import Types as _Types -from flytekit.sdk.workflow import workflow_class, Input, Output from flytekit.models.core.workflow import WorkflowMetadata +from flytekit.sdk.tasks import python_task +from flytekit.sdk.workflow import workflow_class @python_task diff --git a/tests/flytekit/common/workflows/gpu.py b/tests/flytekit/common/workflows/gpu.py index c07f50e0a8..a1c432211f 100644 --- a/tests/flytekit/common/workflows/gpu.py +++ b/tests/flytekit/common/workflows/gpu.py @@ -1,8 +1,8 @@ from __future__ import absolute_import, print_function -from flytekit.sdk.tasks import python_task, inputs, outputs +from flytekit.sdk.tasks import inputs, outputs, python_task from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class, Input, Output +from flytekit.sdk.workflow import Input, Output, workflow_class @inputs(a=Types.Integer) @@ -17,6 +17,6 @@ def add_one(wf_params, a, b): @workflow_class class SimpleWorkflow(object): input_1 = Input(Types.Integer) - input_2 = Input(Types.Integer, default=5, help='Not required.') + input_2 = Input(Types.Integer, default=5, help="Not required.") a = add_one(a=input_1) output = Output(a.outputs.b, sdk_type=Types.Integer) diff --git a/tests/flytekit/common/workflows/hive.py b/tests/flytekit/common/workflows/hive.py index ca605dae39..05f640467c 100644 --- a/tests/flytekit/common/workflows/hive.py +++ b/tests/flytekit/common/workflows/hive.py @@ -1,14 +1,14 @@ -from __future__ import absolute_import -from __future__ import print_function +from __future__ import absolute_import, print_function import six as _six -from flytekit.sdk.tasks import qubole_hive_task, outputs, inputs, python_task -from flytekit.sdk.workflow import workflow_class + +from flytekit.sdk.tasks import inputs, outputs, python_task, qubole_hive_task from flytekit.sdk.types import Types +from flytekit.sdk.workflow import workflow_class @outputs(hive_results=[Types.Schema()]) -@qubole_hive_task(tags=[_six.text_type('these'), _six.text_type('are'), _six.text_type('tags')]) +@qubole_hive_task(tags=[_six.text_type("these"), _six.text_type("are"), _six.text_type("tags")]) def generate_queries(wf_params, hive_results): q1 = "SELECT 1" q2 = "SELECT 'two'" diff --git a/tests/flytekit/common/workflows/nested.py b/tests/flytekit/common/workflows/nested.py index 85c0b6895d..99a5ff89a9 100644 --- a/tests/flytekit/common/workflows/nested.py +++ b/tests/flytekit/common/workflows/nested.py @@ -1,8 +1,8 @@ from __future__ import absolute_import, print_function -from flytekit.sdk.tasks import python_task, inputs, outputs +from flytekit.sdk.tasks import inputs, outputs, python_task from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class, Input, Output +from flytekit.sdk.workflow import Input, Output, workflow_class @inputs(a=Types.Integer) @@ -29,7 +29,7 @@ def sum(wf_params, a, b, c): @workflow_class class Child(object): input_1 = Input(Types.Integer) - input_2 = Input(Types.Integer, default=5, help='Not required.') + input_2 = Input(Types.Integer, default=5, help="Not required.") a = 
add_one(a=input_1) b = add_one(a=input_2) c = add_one(a=100) diff --git a/tests/flytekit/common/workflows/notebook.py b/tests/flytekit/common/workflows/notebook.py index bef7e5b08f..bd61c6a94f 100644 --- a/tests/flytekit/common/workflows/notebook.py +++ b/tests/flytekit/common/workflows/notebook.py @@ -1,24 +1,24 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function -from flytekit.sdk.types import Types +from flytekit.contrib.notebook.tasks import python_notebook, spark_notebook from flytekit.sdk.tasks import inputs, outputs +from flytekit.sdk.types import Types +from flytekit.sdk.workflow import Input, workflow_class -from flytekit.sdk.workflow import workflow_class, Input -from flytekit.contrib.notebook.tasks import python_notebook, spark_notebook +interactive_python = python_notebook( + notebook_path="../../../../notebook-task-examples/python-notebook.ipynb", + inputs=inputs(pi=Types.Float), + outputs=outputs(out=Types.Float), + cpu_request="1", + memory_request="1G", +) -interactive_python = python_notebook(notebook_path="../../../../notebook-task-examples/python-notebook.ipynb", - inputs=inputs(pi=Types.Float), - outputs=outputs(out=Types.Float), - cpu_request="1", - memory_request="1G" - ) +interactive_spark = spark_notebook( + notebook_path="../../../../notebook-task-examples/spark-notebook-pi.ipynb", + inputs=inputs(partitions=Types.Integer), + outputs=outputs(pi=Types.Float), +) -interactive_spark = spark_notebook(notebook_path="../../../../notebook-task-examples/spark-notebook-pi.ipynb", - inputs=inputs(partitions=Types.Integer), - outputs=outputs(pi=Types.Float), - ) @workflow_class class FlyteNotebookSparkWorkflow(object): diff --git a/tests/flytekit/common/workflows/notifications.py b/tests/flytekit/common/workflows/notifications.py index 3e3997c0c8..79031753e6 100644 --- a/tests/flytekit/common/workflows/notifications.py +++ b/tests/flytekit/common/workflows/notifications.py @@ -1,10 +1,10 @@ from __future__ import absolute_import, print_function -from flytekit.sdk.tasks import python_task, inputs, outputs -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class, Input from flytekit.common import notifications as _notifications from flytekit.models.core import execution as _execution +from flytekit.sdk.tasks import inputs, outputs, python_task +from flytekit.sdk.types import Types +from flytekit.sdk.workflow import Input, workflow_class @inputs(a=Types.Integer, b=Types.Integer) @@ -17,12 +17,20 @@ def add_two_integers(wf_params, a, b, c): @workflow_class class BasicWorkflow(object): input_1 = Input(Types.Integer) - input_2 = Input(Types.Integer, default=1, help='Not required.') + input_2 = Input(Types.Integer, default=1, help="Not required.") a = add_two_integers(a=input_1, b=input_2) -notification_lp = BasicWorkflow.create_launch_plan(notifications=[ - _notifications.Email([_execution.WorkflowExecutionPhase.SUCCEEDED, _execution.WorkflowExecutionPhase.FAILED, - _execution.WorkflowExecutionPhase.TIMED_OUT, _execution.WorkflowExecutionPhase.ABORTED], - ['flyte-test-notifications@mydomain.com']) -]) +notification_lp = BasicWorkflow.create_launch_plan( + notifications=[ + _notifications.Email( + [ + _execution.WorkflowExecutionPhase.SUCCEEDED, + _execution.WorkflowExecutionPhase.FAILED, + _execution.WorkflowExecutionPhase.TIMED_OUT, + _execution.WorkflowExecutionPhase.ABORTED, + ], + 
["flyte-test-notifications@mydomain.com"], + ) + ] +) diff --git a/tests/flytekit/common/workflows/presto.py b/tests/flytekit/common/workflows/presto.py index 7c1988f809..317df496c7 100644 --- a/tests/flytekit/common/workflows/presto.py +++ b/tests/flytekit/common/workflows/presto.py @@ -1,9 +1,9 @@ from __future__ import absolute_import +from flytekit.common.tasks.presto_task import SdkPrestoTask from flytekit.sdk.tasks import inputs from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class, Input, Output -from flytekit.common.tasks.presto_task import SdkPrestoTask +from flytekit.sdk.workflow import Input, Output, workflow_class schema = Types.Schema([("a", Types.String), ("b", Types.Integer)]) @@ -22,6 +22,6 @@ class PrestoWorkflow(object): ds = Input(Types.String, required=True, help="Test string with no default") # routing_group = Input(Types.String, required=True, help="Test string with no default") - p_task = presto_task(ds=ds, rg='etl') + p_task = presto_task(ds=ds, rg="etl") output_a = Output(p_task.outputs.results, sdk_type=schema) diff --git a/tests/flytekit/common/workflows/python.py b/tests/flytekit/common/workflows/python.py index 794f6ff838..d9cd9a9c7d 100644 --- a/tests/flytekit/common/workflows/python.py +++ b/tests/flytekit/common/workflows/python.py @@ -1,13 +1,13 @@ from __future__ import absolute_import, division, print_function -from flytekit.sdk.tasks import python_task, inputs, outputs +from flytekit.sdk.tasks import inputs, outputs, python_task from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class, Input +from flytekit.sdk.workflow import Input, workflow_class @inputs(value_to_print=Types.Integer) @outputs(out=Types.Integer) -@python_task(cache_version='1') +@python_task(cache_version="1") def add_one_and_print(workflow_parameters, value_to_print, out): workflow_parameters.stats.incr("task_run") added = value_to_print + 1 @@ -17,7 +17,7 @@ def add_one_and_print(workflow_parameters, value_to_print, out): @inputs(value1_to_print=Types.Integer, value2_to_print=Types.Integer) @outputs(out=Types.Integer) -@python_task(cache_version='1') +@python_task(cache_version="1") def sum_non_none(workflow_parameters, value1_to_print, value2_to_print, out): workflow_parameters.stats.incr("task_run") added = 0 @@ -30,10 +30,11 @@ def sum_non_none(workflow_parameters, value1_to_print, value2_to_print, out): out.set(added) -@inputs(value1_to_add=Types.Integer, value2_to_add=Types.Integer, value3_to_add=Types.Integer, - value4_to_add=Types.Integer) +@inputs( + value1_to_add=Types.Integer, value2_to_add=Types.Integer, value3_to_add=Types.Integer, value4_to_add=Types.Integer, +) @outputs(out=Types.Integer) -@python_task(cache_version='1') +@python_task(cache_version="1") def sum_and_print(workflow_parameters, value1_to_add, value2_to_add, value3_to_add, value4_to_add, out): workflow_parameters.stats.incr("task_run") summed = sum([value1_to_add, value2_to_add, value3_to_add, value4_to_add]) @@ -42,7 +43,7 @@ def sum_and_print(workflow_parameters, value1_to_add, value2_to_add, value3_to_a @inputs(value_to_print=Types.Integer, date_triggered=Types.Datetime) -@python_task(cache_version='1') +@python_task(cache_version="1") def print_every_time(workflow_parameters, value_to_print, date_triggered): workflow_parameters.stats.incr("task_run") print("My printed value: {} @ {}".format(value_to_print, date_triggered)) @@ -53,16 +54,13 @@ class PythonTasksWorkflow(object): triggered_date = Input(Types.Datetime) print1a = 
add_one_and_print(value_to_print=3) print1b = add_one_and_print(value_to_print=101) - print2 = sum_non_none(value1_to_print=print1a.outputs.out, - value2_to_print=print1b.outputs.out) + print2 = sum_non_none(value1_to_print=print1a.outputs.out, value2_to_print=print1b.outputs.out) print3 = add_one_and_print(value_to_print=print2.outputs.out) print4 = add_one_and_print(value_to_print=print3.outputs.out) print_sum = sum_and_print( value1_to_add=print2.outputs.out, value2_to_add=print3.outputs.out, value3_to_add=print4.outputs.out, - value4_to_add=100 + value4_to_add=100, ) - print_always = print_every_time( - value_to_print=print_sum.outputs.out, - date_triggered=triggered_date) + print_always = print_every_time(value_to_print=print_sum.outputs.out, date_triggered=triggered_date) diff --git a/tests/flytekit/common/workflows/raw_container.py b/tests/flytekit/common/workflows/raw_container.py index e260a7acc6..ed3f786174 100644 --- a/tests/flytekit/common/workflows/raw_container.py +++ b/tests/flytekit/common/workflows/raw_container.py @@ -2,7 +2,7 @@ from flytekit.common.tasks.raw_container import SdkRawContainerTask from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class, Input, Output +from flytekit.sdk.workflow import Input, Output, workflow_class square = SdkRawContainerTask( input_data_dir="/var/inputs", diff --git a/tests/flytekit/common/workflows/raw_edge_detector.py b/tests/flytekit/common/workflows/raw_edge_detector.py index 295a846f0f..eeefac4de1 100644 --- a/tests/flytekit/common/workflows/raw_edge_detector.py +++ b/tests/flytekit/common/workflows/raw_edge_detector.py @@ -1,6 +1,6 @@ from flytekit.common.tasks.raw_container import SdkRawContainerTask from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class, Input, Output +from flytekit.sdk.workflow import Input, Output, workflow_class edges = SdkRawContainerTask( input_data_dir="/inputs", @@ -17,4 +17,4 @@ class EdgeDetector(object): script = Input(Types.Blob) image = Input(Types.Blob) edge_task = edges(script=script, image=image) - out = Output(edge_task.outputs.edges, sdk_type=Types.Blob) + out = Output(edge_task.outputs.edges, sdk_type=Types.Blob) diff --git a/tests/flytekit/common/workflows/scala_spark.py b/tests/flytekit/common/workflows/scala_spark.py index 833074b398..37a6abe0c7 100644 --- a/tests/flytekit/common/workflows/scala_spark.py +++ b/tests/flytekit/common/workflows/scala_spark.py @@ -1,12 +1,9 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function +from flytekit.sdk.spark_types import SparkType from flytekit.sdk.tasks import generic_spark_task, inputs, python_task from flytekit.sdk.types import Types -from flytekit.sdk.spark_types import SparkType -from flytekit.sdk.workflow import workflow_class, Input - +from flytekit.sdk.workflow import Input, workflow_class scala_spark = generic_spark_task( spark_type=SparkType.SCALA, @@ -14,17 +11,17 @@ main_class="org.apache.spark.examples.SparkPi", main_application_file="local:///opt/spark/examples/jars/spark-examples.jar", spark_conf={ - 'spark.driver.memory': "1000M", - 'spark.executor.memory': "1000M", - 'spark.executor.cores': '1', - 'spark.executor.instances': '2', - }, - cache_version='1' + "spark.driver.memory": "1000M", + "spark.executor.memory": "1000M", + "spark.executor.cores": "1", + "spark.executor.instances": "2", + }, + cache_version="1", ) 
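The same helper presumably supports other engine types as well; a hedged sketch using a Java main class (`SparkType.JAVA`, the jar path, and the class name here are assumptions, not part of this diff):

```python
from flytekit.sdk.spark_types import SparkType
from flytekit.sdk.tasks import generic_spark_task

# Illustrative Java counterpart to the Scala task above.
java_spark = generic_spark_task(
    spark_type=SparkType.JAVA,
    main_class="com.example.SparkPi",
    main_application_file="local:///opt/spark/jobs/spark-pi.jar",
    spark_conf={"spark.executor.instances": "2"},
    cache_version="1",
)
```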
@inputs(date_triggered=Types.Datetime) -@python_task(cache_version='1') +@python_task(cache_version="1") def print_every_time(workflow_parameters, date_triggered): print("My input : {}".format(date_triggered)) @@ -34,5 +31,4 @@ class SparkTasksWorkflow(object): triggered_date = Input(Types.Datetime) partitions = Input(Types.Integer) spark_task = scala_spark(partitions=partitions) - print_always = print_every_time( - date_triggered=triggered_date) + print_always = print_every_time(date_triggered=triggered_date) diff --git a/tests/flytekit/common/workflows/sidecar.py b/tests/flytekit/common/workflows/sidecar.py index d4cdc3998c..9343b29755 100644 --- a/tests/flytekit/common/workflows/sidecar.py +++ b/tests/flytekit/common/workflows/sidecar.py @@ -1,50 +1,43 @@ -from __future__ import absolute_import -from __future__ import print_function +from __future__ import absolute_import, print_function import os import time +from k8s.io.api.core.v1 import generated_pb2 + from flytekit.sdk.tasks import sidecar_task -from flytekit.sdk.workflow import workflow_class, Input from flytekit.sdk.types import Types -from k8s.io.api.core.v1 import generated_pb2 +from flytekit.sdk.workflow import Input, workflow_class def generate_pod_spec_for_task(): pod_spec = generated_pb2.PodSpec() - secondary_container = generated_pb2.Container( - name="secondary", - image="alpine", - ) + secondary_container = generated_pb2.Container(name="secondary", image="alpine",) secondary_container.command.extend(["/bin/sh"]) secondary_container.args.extend(["-c", "echo hi sidecar world > /data/message.txt"]) - shared_volume_mount = generated_pb2.VolumeMount( - name="shared-data", - mountPath="/data", - ) + shared_volume_mount = generated_pb2.VolumeMount(name="shared-data", mountPath="/data",) secondary_container.volumeMounts.extend([shared_volume_mount]) primary_container = generated_pb2.Container(name="primary") primary_container.volumeMounts.extend([shared_volume_mount]) - pod_spec.volumes.extend([generated_pb2.Volume( - name="shared-data", - volumeSource=generated_pb2.VolumeSource( - emptyDir=generated_pb2.EmptyDirVolumeSource( - medium="Memory", + pod_spec.volumes.extend( + [ + generated_pb2.Volume( + name="shared-data", + volumeSource=generated_pb2.VolumeSource(emptyDir=generated_pb2.EmptyDirVolumeSource(medium="Memory",)), ) - ) - )]) + ] + ) pod_spec.containers.extend([primary_container, secondary_container]) return pod_spec @sidecar_task( - pod_spec=generate_pod_spec_for_task(), - primary_container_name="primary", + pod_spec=generate_pod_spec_for_task(), primary_container_name="primary", ) def a_sidecar_task(wfparams): - while not os.path.isfile('/data/message.txt'): + while not os.path.isfile("/data/message.txt"): time.sleep(5) diff --git a/tests/flytekit/common/workflows/simple.py b/tests/flytekit/common/workflows/simple.py index 3f86df768a..1914bcf7bf 100644 --- a/tests/flytekit/common/workflows/simple.py +++ b/tests/flytekit/common/workflows/simple.py @@ -1,10 +1,11 @@ from __future__ import absolute_import, print_function -from flytekit.sdk.tasks import python_task, inputs, outputs -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class, Input import pandas as _pd +from flytekit.sdk.tasks import inputs, outputs, python_task +from flytekit.sdk.types import Types +from flytekit.sdk.workflow import Input, workflow_class + @inputs(a=Types.Integer) @outputs(b=Types.Integer) @@ -15,7 +16,7 @@ def add_one(wf_params, a, b): @inputs(a=Types.Integer) @outputs(b=Types.Integer) 
-@python_task(cache=True, cache_version='1') +@python_task(cache=True, cache_version="1") def subtract_one(wf_params, a, b): b.set(a - 1) @@ -25,39 +26,34 @@ def subtract_one(wf_params, a, b): b=Types.CSV, c=Types.MultiPartCSV, d=Types.MultiPartBlob, - e=Types.Schema( - [ - ('a', Types.Integer), - ('b', Types.Integer) - ] - ) + e=Types.Schema([("a", Types.Integer), ("b", Types.Integer)]), ) @python_task def write_special_types(wf_params, a, b, c, d, e): blob = Types.Blob() with blob as w: - w.write("hello I'm a blob".encode('utf-8')) + w.write("hello I'm a blob".encode("utf-8")) csv = Types.CSV() with csv as w: w.write("hello,i,iz,blob") mpcsv = Types.MultiPartCSV() - with mpcsv.create_part('000000') as w: + with mpcsv.create_part("000000") as w: w.write("hello,i,iz,blob") - with mpcsv.create_part('000001') as w: + with mpcsv.create_part("000001") as w: w.write("hello,i,iz,blob2") mpblob = Types.MultiPartBlob() - with mpblob.create_part('000000') as w: - w.write("hello I'm a mp blob".encode('utf-8')) - with mpblob.create_part('000001') as w: - w.write("hello I'm a mp blob too".encode('utf-8')) + with mpblob.create_part("000000") as w: + w.write("hello I'm a mp blob".encode("utf-8")) + with mpblob.create_part("000001") as w: + w.write("hello I'm a mp blob too".encode("utf-8")) - schema = Types.Schema([('a', Types.Integer), ('b', Types.Integer)])() + schema = Types.Schema([("a", Types.Integer), ("b", Types.Integer)])() with schema as w: - w.write(_pd.DataFrame.from_dict({'a': [1, 2, 3], 'b': [4, 5, 6]})) - w.write(_pd.DataFrame.from_dict({'a': [3, 2, 1], 'b': [6, 5, 4]})) + w.write(_pd.DataFrame.from_dict({"a": [1, 2, 3], "b": [4, 5, 6]})) + w.write(_pd.DataFrame.from_dict({"a": [3, 2, 1], "b": [6, 5, 4]})) a.set(blob) b.set(csv) @@ -71,17 +67,12 @@ def write_special_types(wf_params, a, b, c, d, e): b=Types.CSV, c=Types.MultiPartCSV, d=Types.MultiPartBlob, - e=Types.Schema( - [ - ('a', Types.Integer), - ('b', Types.Integer) - ] - ) + e=Types.Schema([("a", Types.Integer), ("b", Types.Integer)]), ) @python_task def read_special_types(wf_params, a, b, c, d, e): with a as r: - assert r.read().decode('utf-8') == "hello I'm a blob" + assert r.read().decode("utf-8") == "hello I'm a blob" with b as r: assert r.read() == "hello,i,iz,blob" @@ -93,33 +84,27 @@ def read_special_types(wf_params, a, b, c, d, e): with d as r: assert len(r) == 2 - assert r[0].read().decode('utf-8') == "hello I'm a mp blob" - assert r[1].read().decode('utf-8') == "hello I'm a mp blob too" + assert r[0].read().decode("utf-8") == "hello I'm a mp blob" + assert r[1].read().decode("utf-8") == "hello I'm a mp blob too" with e as r: df = r.read() - assert df['a'].tolist() == [1, 2, 3] - assert df['b'].tolist() == [4, 5, 6] + assert df["a"].tolist() == [1, 2, 3] + assert df["b"].tolist() == [4, 5, 6] df = r.read() - assert df['a'].tolist() == [3, 2, 1] - assert df['b'].tolist() == [6, 5, 4] + assert df["a"].tolist() == [3, 2, 1] + assert df["b"].tolist() == [6, 5, 4] assert r.read() is None @workflow_class class SimpleWorkflow(object): input_1 = Input(Types.Integer) - input_2 = Input(Types.Integer, default=5, help='Not required.') + input_2 = Input(Types.Integer, default=5, help="Not required.") a = add_one(a=input_1) b = add_one(a=input_2) c = subtract_one(a=input_1) d = write_special_types() - e = read_special_types( - a=d.outputs.a, - b=d.outputs.b, - c=d.outputs.c, - d=d.outputs.d, - e=d.outputs.e, - ) + e = read_special_types(a=d.outputs.a, b=d.outputs.b, c=d.outputs.c, d=d.outputs.d, e=d.outputs.e,) diff --git 
a/tests/flytekit/common/workflows/spark.py b/tests/flytekit/common/workflows/spark.py index d6314e044e..2754395d25 100644 --- a/tests/flytekit/common/workflows/spark.py +++ b/tests/flytekit/common/workflows/spark.py @@ -1,29 +1,28 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import random from operator import add from six.moves import range -from flytekit.sdk.tasks import spark_task, inputs, outputs, python_task +from flytekit.sdk.tasks import inputs, outputs, python_task, spark_task from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class, Input +from flytekit.sdk.workflow import Input, workflow_class @inputs(partitions=Types.Integer) @outputs(out=Types.Float) @spark_task( spark_conf={ - 'spark.driver.memory': "1000M", - 'spark.executor.memory': "1000M", - 'spark.executor.cores': '1', - 'spark.executor.instances': '2', - 'spark.hadoop.mapred.output.committer.class': "org.apache.hadoop.mapred.DirectFileOutputCommitter", - 'spark.hadoop.mapreduce.use.directfileoutputcommitter': "true", + "spark.driver.memory": "1000M", + "spark.executor.memory": "1000M", + "spark.executor.cores": "1", + "spark.executor.instances": "2", + "spark.hadoop.mapred.output.committer.class": "org.apache.hadoop.mapred.DirectFileOutputCommitter", + "spark.hadoop.mapreduce.use.directfileoutputcommitter": "true", }, - cache_version='1') + cache_version="1", +) def hello_spark(workflow_parameters, spark_context, partitions, out): print("Starting Spark with Partitions: {}".format(partitions)) @@ -35,7 +34,7 @@ def hello_spark(workflow_parameters, spark_context, partitions, out): @inputs(value_to_print=Types.Float, date_triggered=Types.Datetime) -@python_task(cache_version='1') +@python_task(cache_version="1") def print_every_time(workflow_parameters, value_to_print, date_triggered): print("My printed value: {} @ {}".format(value_to_print, date_triggered)) @@ -50,6 +49,4 @@ def f(_): class SparkTasksWorkflow(object): triggered_date = Input(Types.Datetime) sparkTask = hello_spark(partitions=50) - print_always = print_every_time( - value_to_print=sparkTask.outputs.out, - date_triggered=triggered_date) + print_always = print_every_time(value_to_print=sparkTask.outputs.out, date_triggered=triggered_date) diff --git a/tests/flytekit/loadtests/cp_orchestrator.py b/tests/flytekit/loadtests/cp_orchestrator.py index 2b378b4337..e3b99a7d0b 100644 --- a/tests/flytekit/loadtests/cp_orchestrator.py +++ b/tests/flytekit/loadtests/cp_orchestrator.py @@ -1,9 +1,10 @@ from __future__ import absolute_import, division, print_function +from six.moves import range + from flytekit.sdk.workflow import workflow_class from tests.flytekit.loadtests.cp_python import FlyteCPPythonLoadTestWorkflow from tests.flytekit.loadtests.cp_spark import FlyteCPSparkLoadTestWorkflow -from six.moves import range # launch plans for individual load tests. 
python_loadtest_lp = FlyteCPPythonLoadTestWorkflow.create_launch_plan() diff --git a/tests/flytekit/loadtests/cp_python.py b/tests/flytekit/loadtests/cp_python.py index fddb8b2367..b3a0efd24b 100644 --- a/tests/flytekit/loadtests/cp_python.py +++ b/tests/flytekit/loadtests/cp_python.py @@ -2,17 +2,18 @@ import time -from flytekit.sdk.tasks import python_task, inputs, outputs +from six.moves import range + +from flytekit.sdk.tasks import inputs, outputs, python_task from flytekit.sdk.types import Types from flytekit.sdk.workflow import workflow_class -from six.moves import range @inputs(value1_to_add=Types.Integer, value2_to_add=Types.Integer) @outputs(out=Types.Integer) @python_task(cpu_request="1", cpu_limit="1", memory_request="3G") def sum_and_print(workflow_parameters, value1_to_add, value2_to_add, out): - for i in range(5*60): + for i in range(5 * 60): print("This is load test task. I have been running for {} seconds.".format(i)) time.sleep(1) @@ -26,7 +27,4 @@ class FlyteCPPythonLoadTestWorkflow(object): print_sum = [None] * 5 for i in range(0, 5): - print_sum[i] = sum_and_print( - value1_to_add=1, - value2_to_add=1 - ) + print_sum[i] = sum_and_print(value1_to_add=1, value2_to_add=1) diff --git a/tests/flytekit/loadtests/cp_spark.py b/tests/flytekit/loadtests/cp_spark.py index 7822a3bb2f..b513c028f8 100644 --- a/tests/flytekit/loadtests/cp_spark.py +++ b/tests/flytekit/loadtests/cp_spark.py @@ -1,13 +1,11 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function import random from operator import add from six.moves import range -from flytekit.sdk.tasks import spark_task, inputs, outputs +from flytekit.sdk.tasks import inputs, outputs, spark_task from flytekit.sdk.types import Types from flytekit.sdk.workflow import workflow_class @@ -16,14 +14,15 @@ @outputs(out=Types.Float) @spark_task( spark_conf={ - 'spark.driver.memory': "600M", - 'spark.executor.memory': "600M", - 'spark.executor.cores': '1', - 'spark.executor.instances': '1', - 'spark.hadoop.mapred.output.committer.class': "org.apache.hadoop.mapred.DirectFileOutputCommitter", - 'spark.hadoop.mapreduce.use.directfileoutputcommitter': "true", + "spark.driver.memory": "600M", + "spark.executor.memory": "600M", + "spark.executor.cores": "1", + "spark.executor.instances": "1", + "spark.hadoop.mapred.output.committer.class": "org.apache.hadoop.mapred.DirectFileOutputCommitter", + "spark.hadoop.mapreduce.use.directfileoutputcommitter": "true", }, - cache_version='1') + cache_version="1", +) def hello_spark(workflow_parameters, spark_context, partitions, out): print("Starting Spark with Partitions: {}".format(partitions)) diff --git a/tests/flytekit/loadtests/dynamic_job.py b/tests/flytekit/loadtests/dynamic_job.py index 70f75269e4..8d5fc4f06e 100644 --- a/tests/flytekit/loadtests/dynamic_job.py +++ b/tests/flytekit/loadtests/dynamic_job.py @@ -2,28 +2,29 @@ import time -from flytekit.sdk.tasks import python_task, dynamic_task, inputs, outputs -from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class, Input from six.moves import range +from flytekit.sdk.tasks import dynamic_task, inputs, outputs, python_task +from flytekit.sdk.types import Types +from flytekit.sdk.workflow import Input, workflow_class + @inputs(value1=Types.Integer) @outputs(out=Types.Integer) @python_task(cpu_request="1", cpu_limit="1", memory_request="5G") def dynamic_sub_task(workflow_parameters, value1, out): - for 
i in range(11*60): + for i in range(11 * 60): print("This is load test task. I have been running for {} seconds.".format(i)) time.sleep(1) - output = value1*2 + output = value1 * 2 print("Output: {}".format(output)) out.set(output) @inputs(tasks_count=Types.Integer) @outputs(out=[Types.Integer]) -@dynamic_task(cache_version='1') +@dynamic_task(cache_version="1") def dynamic_task(workflow_parameters, tasks_count, out): res = [] for i in range(0, tasks_count): diff --git a/tests/flytekit/loadtests/orchestrator.py b/tests/flytekit/loadtests/orchestrator.py index 4e7d46583e..c976171ef3 100644 --- a/tests/flytekit/loadtests/orchestrator.py +++ b/tests/flytekit/loadtests/orchestrator.py @@ -1,13 +1,13 @@ from __future__ import absolute_import, division, print_function +from six.moves import range + from flytekit.sdk.types import Types -from flytekit.sdk.workflow import workflow_class, Input -from tests.flytekit.loadtests.python import FlytePythonLoadTestWorkflow +from flytekit.sdk.workflow import Input, workflow_class +from tests.flytekit.loadtests.dynamic_job import FlyteDJOLoadTestWorkflow from tests.flytekit.loadtests.hive import FlyteHiveLoadTestWorkflow +from tests.flytekit.loadtests.python import FlytePythonLoadTestWorkflow from tests.flytekit.loadtests.spark import FlyteSparkLoadTestWorkflow -from tests.flytekit.loadtests.dynamic_job import FlyteDJOLoadTestWorkflow - -from six.moves import range # launch plans for individual load tests. python_loadtest_lp = FlytePythonLoadTestWorkflow.create_launch_plan() diff --git a/tests/flytekit/loadtests/python.py b/tests/flytekit/loadtests/python.py index b40bfca0a3..332b9846f5 100644 --- a/tests/flytekit/loadtests/python.py +++ b/tests/flytekit/loadtests/python.py @@ -2,17 +2,18 @@ import time -from flytekit.sdk.tasks import python_task, inputs, outputs +from six.moves import range + +from flytekit.sdk.tasks import inputs, outputs, python_task from flytekit.sdk.types import Types from flytekit.sdk.workflow import workflow_class -from six.moves import range @inputs(value1_to_add=Types.Integer, value2_to_add=Types.Integer) @outputs(out=Types.Integer) @python_task(cpu_request="5", cpu_limit="5", memory_request="32G") def sum_and_print(workflow_parameters, value1_to_add, value2_to_add, out): - for i in range(11*60): + for i in range(11 * 60): print("This is load test task. 
I have been running for {} seconds.".format(i)) time.sleep(1) @@ -25,7 +26,4 @@ def sum_and_print(workflow_parameters, value1_to_add, value2_to_add, out): class FlytePythonLoadTestWorkflow(object): print_sum = [None] * 30 for i in range(0, 30): - print_sum[i] = sum_and_print( - value1_to_add=1, - value2_to_add=1 - ) + print_sum[i] = sum_and_print(value1_to_add=1, value2_to_add=1) diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 6865fdf3d0..ea2ac10568 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -6,31 +6,29 @@ from click.testing import CliRunner from flyteidl.core import literals_pb2 as _literals_pb2 -from flytekit.bin.entrypoint import execute_task_cmd, _execute_task -from flytekit.common import utils as _utils, constants as _constants +from flytekit.bin.entrypoint import _execute_task, execute_task_cmd +from flytekit.common import constants as _constants +from flytekit.common import utils as _utils from flytekit.common.types import helpers as _type_helpers -from flytekit.models import literals as _literals from flytekit.configuration import TemporaryConfiguration as _TemporaryConfiguration from flytekit.models import literals as _literal_models +from flytekit.models import literals as _literals from tests.flytekit.common import task_definitions as _task_defs def _type_map_from_variable_map(variable_map): - return { - k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in - six.iteritems(variable_map) - } + return {k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in six.iteritems(variable_map)} def test_single_step_entrypoint_in_proc(): - with _TemporaryConfiguration(os.path.join(os.path.dirname(__file__), 'fake.config'), - internal_overrides={ - 'project': 'test', - 'domain': 'development' - }): + with _TemporaryConfiguration( + os.path.join(os.path.dirname(__file__), "fake.config"), + internal_overrides={"project": "test", "domain": "development"}, + ): with _utils.AutoDeletingTempDir("in") as input_dir: literal_map = _type_helpers.pack_python_std_map_to_literal_map( - {'a': 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs)) + {"a": 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs), + ) input_file = os.path.join(input_dir.name, "inputs.pb") _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file) @@ -40,30 +38,29 @@ def test_single_step_entrypoint_in_proc(): _task_defs.add_one.task_function_name, input_file, output_dir.name, - False + False, ) p = _utils.load_proto_from_file( - _literals_pb2.LiteralMap, - os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME) + _literals_pb2.LiteralMap, os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME), ) raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std( _literal_models.LiteralMap.from_flyte_idl(p), - _type_map_from_variable_map(_task_defs.add_one.interface.outputs) + _type_map_from_variable_map(_task_defs.add_one.interface.outputs), ) - assert raw_map['b'] == 10 + assert raw_map["b"] == 10 assert len(raw_map) == 1 def test_single_step_entrypoint_out_of_proc(): - with _TemporaryConfiguration(os.path.join(os.path.dirname(__file__), 'fake.config'), - internal_overrides={ - 'project': 'test', - 'domain': 'development' - }): + with _TemporaryConfiguration( + os.path.join(os.path.dirname(__file__), "fake.config"), + internal_overrides={"project": "test", "domain": "development"}, + ): with 
_utils.AutoDeletingTempDir("in") as input_dir: - literal_map = _type_helpers.pack_python_std_map_to_literal_map({'a': 9}, _type_map_from_variable_map( - _task_defs.add_one.interface.inputs)) + literal_map = _type_helpers.pack_python_std_map_to_literal_map( + {"a": 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs), + ) input_file = os.path.join(input_dir.name, "inputs.pb") _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file) @@ -77,29 +74,28 @@ def test_single_step_entrypoint_out_of_proc(): assert result.exit_code == 0 p = _utils.load_proto_from_file( - _literals_pb2.LiteralMap, - os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME) + _literals_pb2.LiteralMap, os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME), ) raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std( _literal_models.LiteralMap.from_flyte_idl(p), - _type_map_from_variable_map(_task_defs.add_one.interface.outputs) + _type_map_from_variable_map(_task_defs.add_one.interface.outputs), ) - assert raw_map['b'] == 10 + assert raw_map["b"] == 10 assert len(raw_map) == 1 def test_arrayjob_entrypoint_in_proc(): - with _TemporaryConfiguration(os.path.join(os.path.dirname(__file__), 'fake.config'), - internal_overrides={ - 'project': 'test', - 'domain': 'development' - }): + with _TemporaryConfiguration( + os.path.join(os.path.dirname(__file__), "fake.config"), + internal_overrides={"project": "test", "domain": "development"}, + ): with _utils.AutoDeletingTempDir("dir") as dir: literal_map = _type_helpers.pack_python_std_map_to_literal_map( - {'a': 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs)) + {"a": 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs), + ) input_dir = os.path.join(dir.name, "1") - os.mkdir(input_dir) # auto cleanup will take this subdir into account + os.mkdir(input_dir) # auto cleanup will take this subdir into account input_file = os.path.join(input_dir, "inputs.pb") _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file) @@ -111,33 +107,28 @@ def test_arrayjob_entrypoint_in_proc(): _utils.write_proto_to_file(index_lookup_collection.to_flyte_idl(), index_lookup_file) # fake arrayjob task by setting environment variables - orig_env_index_var_name = os.environ.get('BATCH_JOB_ARRAY_INDEX_VAR_NAME') - orig_env_array_index = os.environ.get('AWS_BATCH_JOB_ARRAY_INDEX') - os.environ['BATCH_JOB_ARRAY_INDEX_VAR_NAME'] = 'AWS_BATCH_JOB_ARRAY_INDEX' - os.environ['AWS_BATCH_JOB_ARRAY_INDEX'] = '0' + orig_env_index_var_name = os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME") + orig_env_array_index = os.environ.get("AWS_BATCH_JOB_ARRAY_INDEX") + os.environ["BATCH_JOB_ARRAY_INDEX_VAR_NAME"] = "AWS_BATCH_JOB_ARRAY_INDEX" + os.environ["AWS_BATCH_JOB_ARRAY_INDEX"] = "0" _execute_task( - _task_defs.add_one.task_module, - _task_defs.add_one.task_function_name, - dir.name, - dir.name, - False + _task_defs.add_one.task_module, _task_defs.add_one.task_function_name, dir.name, dir.name, False, ) raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std( _literal_models.LiteralMap.from_flyte_idl( _utils.load_proto_from_file( - _literals_pb2.LiteralMap, - os.path.join(input_dir, _constants.OUTPUT_FILE_NAME) + _literals_pb2.LiteralMap, os.path.join(input_dir, _constants.OUTPUT_FILE_NAME), ) ), - _type_map_from_variable_map(_task_defs.add_one.interface.outputs) + _type_map_from_variable_map(_task_defs.add_one.interface.outputs), ) - assert raw_map['b'] == 10 + assert raw_map["b"] == 10 assert len(raw_map) == 1 # reset the env vars if 
orig_env_index_var_name: - os.environ['BATCH_JOB_ARRAY_INDEX_VAR_NAME'] = orig_env_index_var_name + os.environ["BATCH_JOB_ARRAY_INDEX_VAR_NAME"] = orig_env_index_var_name if orig_env_array_index: - os.environ['AWS_BATCH_JOB_ARRAY_INDEX'] = orig_env_array_index + os.environ["AWS_BATCH_JOB_ARRAY_INDEX"] = orig_env_array_index diff --git a/tests/flytekit/unit/cli/auth/test_auth.py b/tests/flytekit/unit/cli/auth/test_auth.py index 757e6f4797..a3fba13d7b 100644 --- a/tests/flytekit/unit/cli/auth/test_auth.py +++ b/tests/flytekit/unit/cli/auth/test_auth.py @@ -1,10 +1,10 @@ from __future__ import absolute_import import re +from multiprocessing import Queue as _Queue from flytekit.clis.auth import auth as _auth -from multiprocessing import Queue as _Queue try: # Python 3 import http.server as _BaseHTTPServer except ImportError: # Python 2 @@ -15,12 +15,12 @@ def test_generate_code_verifier(): verifier = _auth._generate_code_verifier() assert verifier is not None assert 43 < len(verifier) < 128 - assert not re.search(r'[^a-zA-Z0-9_\-.~]+', verifier) + assert not re.search(r"[^a-zA-Z0-9_\-.~]+", verifier) def test_generate_state_parameter(): param = _auth._generate_state_parameter() - assert not re.search(r'[^a-zA-Z0-9-_.,]+', param) + assert not re.search(r"[^a-zA-Z0-9-_.,]+", param) def test_create_code_challenge(): @@ -35,4 +35,3 @@ def test_oauth_http_server(): server.handle_authorization_code(test_auth_code) auth_code = queue.get() assert test_auth_code == auth_code - diff --git a/tests/flytekit/unit/cli/auth/test_credentials.py b/tests/flytekit/unit/cli/auth/test_credentials.py index 2642e65449..7a5805ee44 100644 --- a/tests/flytekit/unit/cli/auth/test_credentials.py +++ b/tests/flytekit/unit/cli/auth/test_credentials.py @@ -4,38 +4,38 @@ def test_get_discovery_endpoint(): - endpoint = _credentials._get_discovery_endpoint('//localhost:8088', 'localhost:8089', True) - assert endpoint == 'http://localhost:8088/.well-known/oauth-authorization-server' + endpoint = _credentials._get_discovery_endpoint("//localhost:8088", "localhost:8089", True) + assert endpoint == "http://localhost:8088/.well-known/oauth-authorization-server" - endpoint = _credentials._get_discovery_endpoint('//localhost:8088', 'localhost:8089', False) - assert endpoint == 'https://localhost:8088/.well-known/oauth-authorization-server' + endpoint = _credentials._get_discovery_endpoint("//localhost:8088", "localhost:8089", False) + assert endpoint == "https://localhost:8088/.well-known/oauth-authorization-server" - endpoint = _credentials._get_discovery_endpoint('//localhost:8088/path', 'localhost:8089', True) - assert endpoint == 'http://localhost:8088/path/.well-known/oauth-authorization-server' + endpoint = _credentials._get_discovery_endpoint("//localhost:8088/path", "localhost:8089", True) + assert endpoint == "http://localhost:8088/path/.well-known/oauth-authorization-server" - endpoint = _credentials._get_discovery_endpoint('//localhost:8088/path', 'localhost:8089', False) - assert endpoint == 'https://localhost:8088/path/.well-known/oauth-authorization-server' + endpoint = _credentials._get_discovery_endpoint("//localhost:8088/path", "localhost:8089", False) + assert endpoint == "https://localhost:8088/path/.well-known/oauth-authorization-server" - endpoint = _credentials._get_discovery_endpoint('//flyte.corp.com', 'localhost:8089', False) - assert endpoint == 'https://flyte.corp.com/.well-known/oauth-authorization-server' + endpoint = _credentials._get_discovery_endpoint("//flyte.corp.com", "localhost:8089", False) + 
assert endpoint == "https://flyte.corp.com/.well-known/oauth-authorization-server" - endpoint = _credentials._get_discovery_endpoint('//flyte.corp.com/path', 'localhost:8089', False) - assert endpoint == 'https://flyte.corp.com/path/.well-known/oauth-authorization-server' + endpoint = _credentials._get_discovery_endpoint("//flyte.corp.com/path", "localhost:8089", False) + assert endpoint == "https://flyte.corp.com/path/.well-known/oauth-authorization-server" - endpoint = _credentials._get_discovery_endpoint(None, 'localhost:8089', True) - assert endpoint == 'http://localhost:8089/.well-known/oauth-authorization-server' + endpoint = _credentials._get_discovery_endpoint(None, "localhost:8089", True) + assert endpoint == "http://localhost:8089/.well-known/oauth-authorization-server" - endpoint = _credentials._get_discovery_endpoint(None, 'localhost:8089', False) - assert endpoint == 'https://localhost:8089/.well-known/oauth-authorization-server' + endpoint = _credentials._get_discovery_endpoint(None, "localhost:8089", False) + assert endpoint == "https://localhost:8089/.well-known/oauth-authorization-server" - endpoint = _credentials._get_discovery_endpoint(None, 'flyte.corp.com', True) - assert endpoint == 'http://flyte.corp.com/.well-known/oauth-authorization-server' + endpoint = _credentials._get_discovery_endpoint(None, "flyte.corp.com", True) + assert endpoint == "http://flyte.corp.com/.well-known/oauth-authorization-server" - endpoint = _credentials._get_discovery_endpoint(None, 'flyte.corp.com', False) - assert endpoint == 'https://flyte.corp.com/.well-known/oauth-authorization-server' + endpoint = _credentials._get_discovery_endpoint(None, "flyte.corp.com", False) + assert endpoint == "https://flyte.corp.com/.well-known/oauth-authorization-server" - endpoint = _credentials._get_discovery_endpoint(None, 'localhost:8089', True) - assert endpoint == 'http://localhost:8089/.well-known/oauth-authorization-server' + endpoint = _credentials._get_discovery_endpoint(None, "localhost:8089", True) + assert endpoint == "http://localhost:8089/.well-known/oauth-authorization-server" - endpoint = _credentials._get_discovery_endpoint(None, 'localhost:8089', False) - assert endpoint == 'https://localhost:8089/.well-known/oauth-authorization-server' + endpoint = _credentials._get_discovery_endpoint(None, "localhost:8089", False) + assert endpoint == "https://localhost:8089/.well-known/oauth-authorization-server" diff --git a/tests/flytekit/unit/cli/auth/test_discovery.py b/tests/flytekit/unit/cli/auth/test_discovery.py index f4390ee70b..5813d18bf0 100644 --- a/tests/flytekit/unit/cli/auth/test_discovery.py +++ b/tests/flytekit/unit/cli/auth/test_discovery.py @@ -10,9 +10,9 @@ def test_get_authorization_endpoints(): auth_endpoint = "http://flyte-admin.com/authorization" token_endpoint = "http://flyte-admin.com/token" - responses.add(responses.GET, discovery_url, - json={'authorization_endpoint': auth_endpoint, - 'token_endpoint': token_endpoint}) + responses.add( + responses.GET, discovery_url, json={"authorization_endpoint": auth_endpoint, "token_endpoint": token_endpoint}, + ) discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) assert discovery_client.get_authorization_endpoints().auth_endpoint == auth_endpoint @@ -25,9 +25,9 @@ def test_get_authorization_endpoints_relative(): auth_endpoint = "/authorization" token_endpoint = "/token" - responses.add(responses.GET, discovery_url, - json={'authorization_endpoint': auth_endpoint, - 'token_endpoint': token_endpoint}) + 
responses.add( + responses.GET, discovery_url, json={"authorization_endpoint": auth_endpoint, "token_endpoint": token_endpoint}, + ) discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) assert discovery_client.get_authorization_endpoints().auth_endpoint == "http://flyte-admin.com/authorization" @@ -37,7 +37,9 @@ def test_get_authorization_endpoints_relative(): @responses.activate def test_get_authorization_endpoints_missing_authorization_endpoint(): discovery_url = "http://flyte-admin.com/discovery" - responses.add(responses.GET, discovery_url, json={'token_endpoint': "http://flyte-admin.com/token"}) + responses.add( + responses.GET, discovery_url, json={"token_endpoint": "http://flyte-admin.com/token"}, + ) discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) with pytest.raises(Exception): @@ -47,7 +49,9 @@ def test_get_authorization_endpoints_missing_authorization_endpoint(): @responses.activate def test_get_authorization_endpoints_missing_token_endpoint(): discovery_url = "http://flyte-admin.com/discovery" - responses.add(responses.GET, discovery_url, json={'authorization_endpoint': "http://flyte-admin.com/authorization"}) + responses.add( + responses.GET, discovery_url, json={"authorization_endpoint": "http://flyte-admin.com/authorization"}, + ) discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) with pytest.raises(Exception): diff --git a/tests/flytekit/unit/cli/pyflyte/conftest.py b/tests/flytekit/unit/cli/pyflyte/conftest.py index 9fe86bd645..4dc1b10d12 100644 --- a/tests/flytekit/unit/cli/pyflyte/conftest.py +++ b/tests/flytekit/unit/cli/pyflyte/conftest.py @@ -1,39 +1,45 @@ from __future__ import absolute_import + +import os +import sys + +import mock as _mock +import pytest +from click.testing import CliRunner + from flytekit import configuration as _config from flytekit.clis.sdk_in_container import constants as _constants from flytekit.clis.sdk_in_container import pyflyte as _pyflyte from flytekit.tools import module_loader as _module_loader -from click.testing import CliRunner -import mock as _mock -import pytest -import os -import sys def _fake_module_load(names): - assert names == ('common.workflows',) + assert names == ("common.workflows",) from common.workflows import simple + yield simple -@pytest.yield_fixture(scope='function', - params=[ - os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../../common/configs/local.config'), - '/foo/bar', - None - ]) +@pytest.yield_fixture( + scope="function", + params=[ + os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../common/configs/local.config",), + "/foo/bar", + None, + ], +) def mock_ctx(request): with _config.TemporaryConfiguration(request.param): - sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../..')) + sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../..")) try: - with _mock.patch('flytekit.tools.module_loader.iterate_modules') as mock_module_load: + with _mock.patch("flytekit.tools.module_loader.iterate_modules") as mock_module_load: mock_module_load.side_effect = _fake_module_load ctx = _mock.MagicMock() ctx.obj = { - _constants.CTX_PACKAGES: ('common.workflows',), - _constants.CTX_PROJECT: 'tests', - _constants.CTX_DOMAIN: 'unit', - _constants.CTX_VERSION: 'version' + _constants.CTX_PACKAGES: ("common.workflows",), + _constants.CTX_PROJECT: "tests", + _constants.CTX_DOMAIN: "unit", + _constants.CTX_VERSION: "version", } yield ctx finally: @@ -45,10 +51,14 @@ def 
mock_clirunner(monkeypatch): def f(*args, **kwargs): runner = CliRunner() base_args = [ - '-p', 'tests', - '-d', 'unit', - '-v', 'version', - '--pkgs', 'common.workflows', + "-p", + "tests", + "-d", + "unit", + "-v", + "version", + "--pkgs", + "common.workflows", ] result = runner.invoke(_pyflyte.main, base_args + list(args), **kwargs) @@ -58,9 +68,9 @@ def f(*args, **kwargs): return result - tests_dir_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../..') - config_path = os.path.join(tests_dir_path, 'common/configs/local.config') + tests_dir_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../..") + config_path = os.path.join(tests_dir_path, "common/configs/local.config") with _config.TemporaryConfiguration(config_path): monkeypatch.syspath_prepend(tests_dir_path) - monkeypatch.setattr(_module_loader, 'iterate_modules', _fake_module_load) + monkeypatch.setattr(_module_loader, "iterate_modules", _fake_module_load) yield f diff --git a/tests/flytekit/unit/cli/pyflyte/test_basic_auth.py b/tests/flytekit/unit/cli/pyflyte/test_basic_auth.py index 7c4436113b..421713ee16 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_basic_auth.py +++ b/tests/flytekit/unit/cli/pyflyte/test_basic_auth.py @@ -1,18 +1,19 @@ from __future__ import absolute_import + import json -from mock import MagicMock, patch, PropertyMock +from mock import MagicMock, patch + from flytekit.clis.flyte_cli.main import _welcome_message from flytekit.clis.sdk_in_container import basic_auth -from flytekit.configuration.creds import ( - CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET, -) +from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET _welcome_message() def test_get_secret(): import os + os.environ[_CREDENTIALS_SECRET.env_var] = "abc" assert basic_auth.get_secret() == "abc" @@ -22,7 +23,7 @@ def test_get_basic_authorization_header(): assert header == "Basic Y2xpZW50X2lkOmFiYw==" -@patch('flytekit.clis.sdk_in_container.basic_auth._requests') +@patch("flytekit.clis.sdk_in_container.basic_auth._requests") def test_get_token(mock_requests): response = MagicMock() response.status_code = 200 diff --git a/tests/flytekit/unit/cli/pyflyte/test_launch_plans.py b/tests/flytekit/unit/cli/pyflyte/test_launch_plans.py index 993d75eebd..6669ce350c 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_launch_plans.py +++ b/tests/flytekit/unit/cli/pyflyte/test_launch_plans.py @@ -1,30 +1,32 @@ from __future__ import absolute_import + +import pytest + from flytekit.clis.sdk_in_container import launch_plan from flytekit.clis.sdk_in_container.launch_plan import launch_plans -import pytest def test_list_commands(mock_ctx): - g = launch_plan.LaunchPlanExecuteGroup('test_group') + g = launch_plan.LaunchPlanExecuteGroup("test_group") v = g.list_commands(mock_ctx) - assert v == ['common.workflows.simple.SimpleWorkflow'] + assert v == ["common.workflows.simple.SimpleWorkflow"] def test_get_commands(mock_ctx): - g = launch_plan.LaunchPlanExecuteGroup('test_group') - v = g.get_command(mock_ctx, 'common.workflows.simple.SimpleWorkflow') - assert v.params[0].human_readable_name == 'input_1' - assert 'INTEGER' in v.params[0].help - assert v.params[1].human_readable_name == 'input_2' - assert 'INTEGER' in v.params[1].help - assert 'Not required.' 
in v.params[1].help + g = launch_plan.LaunchPlanExecuteGroup("test_group") + v = g.get_command(mock_ctx, "common.workflows.simple.SimpleWorkflow") + assert v.params[0].human_readable_name == "input_1" + assert "INTEGER" in v.params[0].help + assert v.params[1].human_readable_name == "input_2" + assert "INTEGER" in v.params[1].help + assert "Not required." in v.params[1].help with pytest.raises(Exception): - g.get_command(mock_ctx, 'common.workflows.simple.DoesNotExist') + g.get_command(mock_ctx, "common.workflows.simple.DoesNotExist") with pytest.raises(Exception): - g.get_command(mock_ctx, 'does.not.exist') + g.get_command(mock_ctx, "does.not.exist") def test_launch_plans_commands(mock_ctx): command_names = [c for c in launch_plans.list_commands(mock_ctx)] - assert command_names == sorted(['execute', 'activate-all', 'activate-all-schedules']) + assert command_names == sorted(["execute", "activate-all", "activate-all-schedules"]) diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index b9f5388ae8..a16fc8e59a 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -1,20 +1,21 @@ -from click.testing import CliRunner -import pytest -from mock import MagicMock, PropertyMock, patch +from mock import MagicMock -from flytekit.clis.sdk_in_container import constants as _constants from flytekit.engines.flyte import engine def test_register_workflows(mock_clirunner, monkeypatch): mock_get_task = MagicMock() - monkeypatch.setattr(engine.FlyteEngineFactory, 'get_task', MagicMock(return_value=mock_get_task)) + monkeypatch.setattr(engine.FlyteEngineFactory, "get_task", MagicMock(return_value=mock_get_task)) mock_get_workflow = MagicMock() - monkeypatch.setattr(engine.FlyteEngineFactory, 'get_workflow', MagicMock(return_value=mock_get_workflow)) + monkeypatch.setattr( + engine.FlyteEngineFactory, "get_workflow", MagicMock(return_value=mock_get_workflow), + ) mock_get_launch_plan = MagicMock() - monkeypatch.setattr(engine.FlyteEngineFactory, 'get_launch_plan', MagicMock(return_value=mock_get_launch_plan)) + monkeypatch.setattr( + engine.FlyteEngineFactory, "get_launch_plan", MagicMock(return_value=mock_get_launch_plan), + ) - result = mock_clirunner('register', 'workflows') + result = mock_clirunner("register", "workflows") assert result.exit_code == 0 @@ -28,13 +29,17 @@ def test_register_workflows(mock_clirunner, monkeypatch): def test_register_workflows_with_test_switch(mock_clirunner, monkeypatch): mock_get_task = MagicMock() - monkeypatch.setattr(engine.FlyteEngineFactory, 'get_task', MagicMock(return_value=mock_get_task)) + monkeypatch.setattr(engine.FlyteEngineFactory, "get_task", MagicMock(return_value=mock_get_task)) mock_get_workflow = MagicMock() - monkeypatch.setattr(engine.FlyteEngineFactory, 'get_workflow', MagicMock(return_value=mock_get_workflow)) + monkeypatch.setattr( + engine.FlyteEngineFactory, "get_workflow", MagicMock(return_value=mock_get_workflow), + ) mock_get_launch_plan = MagicMock() - monkeypatch.setattr(engine.FlyteEngineFactory, 'get_launch_plan', MagicMock(return_value=mock_get_launch_plan)) + monkeypatch.setattr( + engine.FlyteEngineFactory, "get_launch_plan", MagicMock(return_value=mock_get_launch_plan), + ) - result = mock_clirunner('register', '--test', 'workflows') + result = mock_clirunner("register", "--test", "workflows") assert result.exit_code == 0 diff --git a/tests/flytekit/unit/cli/test_flyte_cli.py 
b/tests/flytekit/unit/cli/test_flyte_cli.py index 530d8c03f7..adcd07dff9 100644 --- a/tests/flytekit/unit/cli/test_flyte_cli.py +++ b/tests/flytekit/unit/cli/test_flyte_cli.py @@ -4,20 +4,21 @@ import pytest from flytekit.clis.flyte_cli import main as _main +from flytekit.common.exceptions.user import FlyteAssertion from flytekit.common.types import primitives from flytekit.configuration import TemporaryConfiguration from flytekit.models.core import identifier as _core_identifier -from flytekit.sdk.tasks import python_task, inputs, outputs -from flytekit.common.exceptions.user import FlyteAssertion +from flytekit.sdk.tasks import inputs, outputs, python_task mm = _mock.MagicMock() -mm.return_value=100 +mm.return_value = 100 def get_sample_task(): """ :rtype: flytekit.common.tasks.task.SdkTask """ + @inputs(a=primitives.Integer) @outputs(b=primitives.Integer) @python_task() @@ -27,40 +28,37 @@ def my_task(wf_params, a, b): return my_task -@_mock.patch('flytekit.clis.flyte_cli.main._load_proto_from_file') +@_mock.patch("flytekit.clis.flyte_cli.main._load_proto_from_file") def test__extract_files(load_mock): - id = _core_identifier.Identifier(_core_identifier.ResourceType.TASK, 'myproject', 'development', 'name', 'v') + id = _core_identifier.Identifier(_core_identifier.ResourceType.TASK, "myproject", "development", "name", "v") t = get_sample_task() with TemporaryConfiguration( - "", - internal_overrides={ - 'image': 'myflyteimage:v123', - 'project': 'myflyteproject', - 'domain': 'development' - } + "", internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, ): task_spec = t.serialize() load_mock.side_effect = [id.to_flyte_idl(), task_spec] - new_id, entity = _main._extract_pair('a', 'b') + new_id, entity = _main._extract_pair("a", "b") assert new_id == id.to_flyte_idl() assert task_spec == entity -@_mock.patch('flytekit.clis.flyte_cli.main._load_proto_from_file') -def test__extract_files(load_mock): - id = _core_identifier.Identifier(_core_identifier.ResourceType.UNSPECIFIED, 'myproject', 'development', 'name', 'v') +@_mock.patch("flytekit.clis.flyte_cli.main._load_proto_from_file") +def test__extract_files_with_unspecified_resource_type(load_mock): + id = _core_identifier.Identifier( + _core_identifier.ResourceType.UNSPECIFIED, "myproject", "development", "name", "v", + ) load_mock.return_value = id.to_flyte_idl() with pytest.raises(FlyteAssertion): - _main._extract_pair('a', 'b') + _main._extract_pair("a", "b") def _identity_dummy(a, b): return (a, b) -@_mock.patch('flytekit.clis.flyte_cli.main._extract_pair', new=_identity_dummy) -def test__extract_files(): - results = _main._extract_files([1,2,3,4]) +@_mock.patch("flytekit.clis.flyte_cli.main._extract_pair", new=_identity_dummy) +def test__extract_files_pair_iterator(): + results = _main._extract_files([1, 2, 3, 4]) assert [(1, 2), (3, 4)] == results diff --git a/tests/flytekit/unit/cli/test_helpers.py b/tests/flytekit/unit/cli/test_helpers.py index f6c9b590a8..3ed7e05daf 100644 --- a/tests/flytekit/unit/cli/test_helpers.py +++ b/tests/flytekit/unit/cli/test_helpers.py @@ -3,21 +3,20 @@ import pytest from flytekit.clis import helpers -from flytekit.models import literals -from flytekit.models import types -from flytekit.models.interface import Variable, Parameter, ParameterMap +from flytekit.models import literals, types +from flytekit.models.interface import Parameter, ParameterMap, Variable def test_parse_args_into_dict(): - sample_args1 = (u'input_b=mystr', u'input_c=18') - sample_args2 
= ('input_a=mystr===d',) + sample_args1 = (u"input_b=mystr", u"input_c=18") + sample_args2 = ("input_a=mystr===d",) sample_args3 = () output = helpers.parse_args_into_dict(sample_args1) - assert output['input_b'] == 'mystr' - assert output['input_c'] == '18' + assert output["input_b"] == "mystr" + assert output["input_c"] == "18" output = helpers.parse_args_into_dict(sample_args2) - assert output['input_a'] == 'mystr===d' + assert output["input_a"] == "mystr===d" output = helpers.parse_args_into_dict(sample_args3) assert output == {} @@ -26,13 +25,13 @@ def test_parse_args_into_dict(): def test_construct_literal_map_from_variable_map(): v = Variable(type=types.LiteralType(simple=types.SimpleType.INTEGER), description="some description") variable_map = { - 'inputa': v, + "inputa": v, } - input_txt_dictionary = {'inputa': '15'} + input_txt_dictionary = {"inputa": "15"} literal_map = helpers.construct_literal_map_from_variable_map(variable_map, input_txt_dictionary) - parsed_literal = literal_map.literals['inputa'].value + parsed_literal = literal_map.literals["inputa"].value ll = literals.Scalar(primitive=literals.Primitive(integer=15)) assert parsed_literal == ll @@ -40,14 +39,12 @@ def test_construct_literal_map_from_variable_map(): def test_construct_literal_map_from_parameter_map(): v = Variable(type=types.LiteralType(simple=types.SimpleType.INTEGER), description="some description") p = Parameter(var=v, required=True) - pm = ParameterMap(parameters={ - 'inputa': p, - }) + pm = ParameterMap(parameters={"inputa": p}) - input_txt_dictionary = {'inputa': '15'} + input_txt_dictionary = {"inputa": "15"} literal_map = helpers.construct_literal_map_from_parameter_map(pm, input_txt_dictionary) - parsed_literal = literal_map.literals['inputa'].value + parsed_literal = literal_map.literals["inputa"].value ll = literals.Scalar(primitive=literals.Primitive(integer=15)) assert parsed_literal == ll @@ -56,10 +53,10 @@ def test_construct_literal_map_from_parameter_map(): def test_strtobool(): - assert not helpers.str2bool('False') - assert not helpers.str2bool('OFF') - assert not helpers.str2bool('no') - assert not helpers.str2bool('0') - assert helpers.str2bool('t') - assert helpers.str2bool('true') - assert helpers.str2bool('stuff') + assert not helpers.str2bool("False") + assert not helpers.str2bool("OFF") + assert not helpers.str2bool("no") + assert not helpers.str2bool("0") + assert helpers.str2bool("t") + assert helpers.str2bool("true") + assert helpers.str2bool("stuff") diff --git a/tests/flytekit/unit/clients/test_raw.py b/tests/flytekit/unit/clients/test_raw.py index 6a0fdcc1dc..e42ab5136f 100644 --- a/tests/flytekit/unit/clients/test_raw.py +++ b/tests/flytekit/unit/clients/test_raw.py @@ -1,29 +1,30 @@ from __future__ import absolute_import -from flytekit.clients.raw import ( - RawSynchronousFlyteClient as _RawSynchronousFlyteClient, - _refresh_credentials_basic -) -from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET -from flytekit.clis.auth.discovery import AuthorizationEndpoints as _AuthorizationEndpoints -import mock -import os + import json +import os + +import mock + +from flytekit.clients.raw import RawSynchronousFlyteClient as _RawSynchronousFlyteClient +from flytekit.clients.raw import _refresh_credentials_basic +from flytekit.clis.auth.discovery import AuthorizationEndpoints as _AuthorizationEndpoints +from flytekit.configuration.creds import CLIENT_CREDENTIALS_SECRET as _CREDENTIALS_SECRET -@mock.patch('flytekit.clients.raw._admin_service') 
-@mock.patch('flytekit.clients.raw._insecure_channel') +@mock.patch("flytekit.clients.raw._admin_service") +@mock.patch("flytekit.clients.raw._insecure_channel") def test_client_set_token(mock_channel, mock_admin): mock_channel.return_value = True mock_admin.AdminServiceStub.return_value = True - client = _RawSynchronousFlyteClient(url='a.b.com', insecure=True) - client.set_access_token('abc') - assert client._metadata[0][1] == 'Bearer abc' + client = _RawSynchronousFlyteClient(url="a.b.com", insecure=True) + client.set_access_token("abc") + assert client._metadata[0][1] == "Bearer abc" -@mock.patch('flytekit.clis.sdk_in_container.basic_auth._requests') -@mock.patch('flytekit.clients.raw._credentials_access') +@mock.patch("flytekit.clis.sdk_in_container.basic_auth._requests") +@mock.patch("flytekit.clients.raw._credentials_access") def test_refresh_credentials_basic(mock_credentials_access, mock_requests): - mock_credentials_access.get_authorization_endpoints.return_value = _AuthorizationEndpoints('auth', 'token') + mock_credentials_access.get_authorization_endpoints.return_value = _AuthorizationEndpoints("auth", "token") response = mock.MagicMock() response.status_code = 200 response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") @@ -31,7 +32,7 @@ def test_refresh_credentials_basic(mock_credentials_access, mock_requests): os.environ[_CREDENTIALS_SECRET.env_var] = "asdf12345" mock_client = mock.MagicMock() - mock_client.url.return_value = 'flyte.localhost' + mock_client.url.return_value = "flyte.localhost" _refresh_credentials_basic(mock_client) - mock_client.set_access_token.assert_called_with('abc') + mock_client.set_access_token.assert_called_with("abc") mock_credentials_access.get_authorization_endpoints.assert_called_with(mock_client.url) diff --git a/tests/flytekit/unit/common_tests/exceptions/test_base.py b/tests/flytekit/unit/common_tests/exceptions/test_base.py index db6b850dd0..58ef24b6dd 100644 --- a/tests/flytekit/unit/common_tests/exceptions/test_base.py +++ b/tests/flytekit/unit/common_tests/exceptions/test_base.py @@ -1,4 +1,5 @@ from __future__ import absolute_import + from flytekit.common.exceptions import base diff --git a/tests/flytekit/unit/common_tests/exceptions/test_scopes.py b/tests/flytekit/unit/common_tests/exceptions/test_scopes.py index 820deef557..37ac077268 100644 --- a/tests/flytekit/unit/common_tests/exceptions/test_scopes.py +++ b/tests/flytekit/unit/common_tests/exceptions/test_scopes.py @@ -1,8 +1,10 @@ from __future__ import absolute_import -from flytekit.common.exceptions import system, user, scopes -from flytekit.models.core import errors as _error_models + import pytest +from flytekit.common.exceptions import scopes, system, user +from flytekit.models.core import errors as _error_models + @scopes.user_entry_point def _user_func(ex_to_raise): diff --git a/tests/flytekit/unit/common_tests/exceptions/test_system.py b/tests/flytekit/unit/common_tests/exceptions/test_system.py index d159c72b97..ed574eea3a 100644 --- a/tests/flytekit/unit/common_tests/exceptions/test_system.py +++ b/tests/flytekit/unit/common_tests/exceptions/test_system.py @@ -1,5 +1,6 @@ from __future__ import absolute_import -from flytekit.common.exceptions import system, base + +from flytekit.common.exceptions import base, system def test_flyte_system_exception(): @@ -40,20 +41,22 @@ def test_flyte_entrypoint_not_loadable_exception(): try: raise system.FlyteEntrypointNotLoadable("fake.module", additional_msg="Shouldn't have used a fake module!") except 
Exception as e: - assert str(e) == "Entrypoint is not loadable! Could not load the module: 'fake.module' "\ - "due to error: Shouldn't have used a fake module!" + assert ( + str(e) == "Entrypoint is not loadable! Could not load the module: 'fake.module' " + "due to error: Shouldn't have used a fake module!" + ) assert type(e).error_code == "SYSTEM:UnloadableCode" assert isinstance(e, system.FlyteSystemException) try: raise system.FlyteEntrypointNotLoadable( - "fake.module", - task_name="secret_task", - additional_msg="Shouldn't have used a fake module!" + "fake.module", task_name="secret_task", additional_msg="Shouldn't have used a fake module!", ) except Exception as e: - assert str(e) == "Entrypoint is not loadable! Could not find the task: 'secret_task' in 'fake.module' " \ - "due to error: Shouldn't have used a fake module!" + assert ( + str(e) == "Entrypoint is not loadable! Could not find the task: 'secret_task' in 'fake.module' " + "due to error: Shouldn't have used a fake module!" + ) assert type(e).error_code == "SYSTEM:UnloadableCode" assert isinstance(e, system.FlyteSystemException) diff --git a/tests/flytekit/unit/common_tests/exceptions/test_user.py b/tests/flytekit/unit/common_tests/exceptions/test_user.py index 146401f793..ee4d02bf29 100644 --- a/tests/flytekit/unit/common_tests/exceptions/test_user.py +++ b/tests/flytekit/unit/common_tests/exceptions/test_user.py @@ -1,5 +1,6 @@ from __future__ import absolute_import -from flytekit.common.exceptions import user, base + +from flytekit.common.exceptions import base, user def test_flyte_user_exception(): @@ -14,7 +15,7 @@ def test_flyte_user_exception(): def test_flyte_type_exception(): try: - raise user.FlyteTypeException('int', 'float', received_value=1, additional_msg='That was a bad idea!') + raise user.FlyteTypeException("int", "float", received_value=1, additional_msg="That was a bad idea!") except Exception as e: assert str(e) == "Type error! Received: int with value: 1, Expected: float. That was a bad idea!" assert isinstance(e, TypeError) @@ -22,16 +23,20 @@ def test_flyte_type_exception(): assert isinstance(e, user.FlyteUserException) try: - raise user.FlyteTypeException('int', ('list', 'set'), received_value=1, additional_msg='That was a bad idea!') + raise user.FlyteTypeException( + "int", ("list", "set"), received_value=1, additional_msg="That was a bad idea!", + ) except Exception as e: - assert str(e) == "Type error! Received: int with value: 1, Expected one of: ('list', 'set'). That was a " \ - "bad idea!" + assert ( + str(e) == "Type error! Received: int with value: 1, Expected one of: ('list', 'set'). That was a " + "bad idea!" + ) assert isinstance(e, TypeError) assert type(e).error_code == "USER:TypeError" assert isinstance(e, user.FlyteUserException) try: - raise user.FlyteTypeException('int', 'float', additional_msg='That was a bad idea!') + raise user.FlyteTypeException("int", "float", additional_msg="That was a bad idea!") except Exception as e: assert str(e) == "Type error! Received: int, Expected: float. That was a bad idea!" assert isinstance(e, TypeError) @@ -39,10 +44,9 @@ def test_flyte_type_exception(): assert isinstance(e, user.FlyteUserException) try: - raise user.FlyteTypeException('int', ('list', 'set'), additional_msg='That was a bad idea!') + raise user.FlyteTypeException("int", ("list", "set"), additional_msg="That was a bad idea!") except Exception as e: - assert str(e) == "Type error! Received: int, Expected one of: ('list', 'set'). That was a " \ - "bad idea!" 
+ assert str(e) == "Type error! Received: int, Expected one of: ('list', 'set'). That was a " "bad idea!" assert isinstance(e, TypeError) assert type(e).error_code == "USER:TypeError" assert isinstance(e, user.FlyteUserException) diff --git a/tests/flytekit/unit/common_tests/mixins/sample_registerable.py b/tests/flytekit/unit/common_tests/mixins/sample_registerable.py index f79119ddb7..a2e3ede79f 100644 --- a/tests/flytekit/unit/common_tests/mixins/sample_registerable.py +++ b/tests/flytekit/unit/common_tests/mixins/sample_registerable.py @@ -1,9 +1,10 @@ from __future__ import absolute_import -from flytekit.common.mixins import registerable as _registerable -from flytekit.common import interface as _interface, nodes as _nodes, sdk_bases as _sdk_bases import six as _six +from flytekit.common import sdk_bases as _sdk_bases +from flytekit.common.mixins import registerable as _registerable + class ExampleRegisterable(_six.with_metaclass(_sdk_bases.ExtendedSdkType, _registerable.RegisterableEntity)): def __init__(self, *args, **kwargs): diff --git a/tests/flytekit/unit/common_tests/mixins/test_registerable.py b/tests/flytekit/unit/common_tests/mixins/test_registerable.py index 8336bfa11a..8f0d82b19e 100644 --- a/tests/flytekit/unit/common_tests/mixins/test_registerable.py +++ b/tests/flytekit/unit/common_tests/mixins/test_registerable.py @@ -9,5 +9,7 @@ def test_instance_tracker(): def test_auto_name_assignment(): _sample_registerable.example.auto_assign_name() - assert _sample_registerable.example.platform_valid_name == \ - "tests.flytekit.unit.common_tests.mixins.sample_registerable.example" + assert ( + _sample_registerable.example.platform_valid_name + == "tests.flytekit.unit.common_tests.mixins.sample_registerable.example" + ) diff --git a/tests/flytekit/unit/common_tests/tasks/spark/test_spark_task.py b/tests/flytekit/unit/common_tests/tasks/spark/test_spark_task.py index d9e3139b60..29e9e3254a 100644 --- a/tests/flytekit/unit/common_tests/tasks/spark/test_spark_task.py +++ b/tests/flytekit/unit/common_tests/tasks/spark/test_spark_task.py @@ -1,10 +1,10 @@ from __future__ import absolute_import -from flytekit.sdk.tasks import spark_task, outputs -from flytekit.sdk.types import Types - from six.moves import range +from flytekit.sdk.tasks import outputs, spark_task +from flytekit.sdk.types import Types + # This file is in a subdirectory to make it easier to exclude when not running in a container # and pyspark is not available @@ -14,8 +14,8 @@ def my_spark_task(wf, sc, out): def _inside(p): return p < 1000 - count = sc.parallelize(range(0, 10000)) \ - .filter(_inside).count() + + count = sc.parallelize(range(0, 10000)).filter(_inside).count() out.set(count) @@ -26,14 +26,14 @@ def my_spark_task2(wf, sc, out): # modules. 
def _inside(p): return p < 500 - count = sc.parallelize(range(0, 10000)) \ - .filter(_inside).count() + + count = sc.parallelize(range(0, 10000)).filter(_inside).count() out.set(count) def test_basic_spark_execution(): outputs = my_spark_task.unit_test() - assert outputs['out'] == 1000 + assert outputs["out"] == 1000 outputs = my_spark_task2.unit_test() - assert outputs['out'] == 500 + assert outputs["out"] == 500 diff --git a/tests/flytekit/unit/common_tests/tasks/test_sdk_runnable.py b/tests/flytekit/unit/common_tests/tasks/test_sdk_runnable.py index 57aa1485db..7c4763046a 100644 --- a/tests/flytekit/unit/common_tests/tasks/test_sdk_runnable.py +++ b/tests/flytekit/unit/common_tests/tasks/test_sdk_runnable.py @@ -1,15 +1,15 @@ from __future__ import absolute_import +import pytest as _pytest + from flytekit.common import constants as _common_constants from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.tasks import sdk_runnable from flytekit.common.types import primitives from flytekit.models import interface -import pytest as _pytest def test_basic_unit_test(): - def add_one(wf_params, value_in, value_out): value_out.set(value_in + 1) @@ -33,10 +33,10 @@ def add_one(wf_params, value_in, value_out): {}, None, ) - t.add_inputs({'value_in': interface.Variable(primitives.Integer.to_flyte_literal_type(), "")}) - t.add_outputs({'value_out': interface.Variable(primitives.Integer.to_flyte_literal_type(), "")}) + t.add_inputs({"value_in": interface.Variable(primitives.Integer.to_flyte_literal_type(), "")}) + t.add_outputs({"value_out": interface.Variable(primitives.Integer.to_flyte_literal_type(), "")}) out = t.unit_test(value_in=1) - assert out['value_out'] == 2 + assert out["value_out"] == 2 with _pytest.raises(_user_exceptions.FlyteAssertion) as e: t() diff --git a/tests/flytekit/unit/common_tests/tasks/test_task.py b/tests/flytekit/unit/common_tests/tasks/test_task.py index 094448f7f7..1c6d52b9fc 100644 --- a/tests/flytekit/unit/common_tests/tasks/test_task.py +++ b/tests/flytekit/unit/common_tests/tasks/test_task.py @@ -1,31 +1,30 @@ from __future__ import absolute_import -import pytest as _pytest import os as _os -from mock import patch as _patch, MagicMock as _MagicMock -from flytekit.configuration import TemporaryConfiguration +import pytest as _pytest +from flyteidl.admin import task_pb2 as _admin_task_pb2 +from mock import MagicMock as _MagicMock +from mock import patch as _patch + from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.tasks import task as _task +from flytekit.common.tasks.presto_task import SdkPrestoTask from flytekit.common.types import primitives +from flytekit.configuration import TemporaryConfiguration from flytekit.models import task as _task_models from flytekit.models.core import identifier as _identifier -from flytekit.sdk.tasks import python_task, inputs, outputs -from flyteidl.admin import task_pb2 as _admin_task_pb2 -from flytekit.common.tasks.presto_task import SdkPrestoTask +from flytekit.sdk.tasks import inputs, outputs, python_task from flytekit.sdk.types import Types @_patch("flytekit.engines.loader.get_engine") def test_fetch_latest(mock_get_engine): admin_task = _task_models.Task( - _identifier.Identifier(_identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"), - _MagicMock(), + _identifier.Identifier(_identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"), _MagicMock(), ) mock_engine = _MagicMock() - mock_engine.fetch_latest_task = _MagicMock( - return_value=admin_task - ) + 
mock_engine.fetch_latest_task = _MagicMock(return_value=admin_task) mock_get_engine.return_value = mock_engine task = _task.SdkTask.fetch_latest("p1", "d1", "n1") assert task.id == admin_task.id @@ -34,9 +33,7 @@ def test_fetch_latest(mock_get_engine): @_patch("flytekit.engines.loader.get_engine") def test_fetch_latest_not_exist(mock_get_engine): mock_engine = _MagicMock() - mock_engine.fetch_latest_task = _MagicMock( - return_value=None - ) + mock_engine.fetch_latest_task = _MagicMock(return_value=None) mock_get_engine.return_value = mock_engine with _pytest.raises(_user_exceptions.FlyteEntityNotExistException): _task.SdkTask.fetch_latest("p1", "d1", "n1") @@ -46,6 +43,7 @@ def get_sample_task(): """ :rtype: flytekit.common.tasks.task.SdkTask """ + @inputs(a=primitives.Integer) @outputs(b=primitives.Integer) @python_task() @@ -58,18 +56,14 @@ def my_task(wf_params, a, b): def test_task_serialization(): t = get_sample_task() with TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), '../../../common/configs/local.config'), - internal_overrides={ - 'image': 'myflyteimage:v123', - 'project': 'myflyteproject', - 'domain': 'development' - } + _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../../common/configs/local.config",), + internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, ): s = t.serialize() assert isinstance(s, _admin_task_pb2.TaskSpec) - assert s.template.id.name == 'tests.flytekit.unit.common_tests.tasks.test_task.my_task' - assert s.template.container.image == 'myflyteimage:v123' + assert s.template.id.name == "tests.flytekit.unit.common_tests.tasks.test_task.my_task" + assert s.template.container.image == "myflyteimage:v123" schema = Types.Schema([("a", Types.String), ("b", Types.Integer)]) @@ -94,11 +88,15 @@ def test_task_produce_deterministic_version(): output_schema=schema, routing_group="{{ .Inputs.rg }}", ) - assert containerless_task._produce_deterministic_version() ==\ - identical_containerless_task._produce_deterministic_version() + assert ( + containerless_task._produce_deterministic_version() + == identical_containerless_task._produce_deterministic_version() + ) - assert containerless_task._produce_deterministic_version() !=\ - different_containerless_task._produce_deterministic_version() + assert ( + containerless_task._produce_deterministic_version() + != different_containerless_task._produce_deterministic_version() + ) with _pytest.raises(Exception): get_sample_task()._produce_deterministic_version() diff --git a/tests/flytekit/unit/common_tests/test_interface.py b/tests/flytekit/unit/common_tests/test_interface.py index 3b2f0d002a..2adaf237dc 100644 --- a/tests/flytekit/unit/common_tests/test_interface.py +++ b/tests/flytekit/unit/common_tests/test_interface.py @@ -1,16 +1,16 @@ from __future__ import absolute_import + +import pytest + from flytekit.common import interface from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import primitives, containers -import pytest +from flytekit.common.types import containers, primitives def test_binding_data_primitive_static(): upstream_nodes = set() bd = interface.BindingData.from_python_std( - primitives.Float.to_flyte_literal_type(), - 3.0, - upstream_nodes=upstream_nodes + primitives.Float.to_flyte_literal_type(), 3.0, upstream_nodes=upstream_nodes ) assert len(upstream_nodes) == 0 @@ -23,29 +23,25 @@ def test_binding_data_primitive_static(): with 
     with pytest.raises(_user_exceptions.FlyteTypeException):
         interface.BindingData.from_python_std(
-            primitives.Float.to_flyte_literal_type(),
-            'abc',
+            primitives.Float.to_flyte_literal_type(), "abc",
         )

     with pytest.raises(_user_exceptions.FlyteTypeException):
         interface.BindingData.from_python_std(
-            primitives.Float.to_flyte_literal_type(),
-            [1.0, 2.0, 3.0],
+            primitives.Float.to_flyte_literal_type(), [1.0, 2.0, 3.0],
         )


 def test_binding_data_list_static():
     upstream_nodes = set()
     bd = interface.BindingData.from_python_std(
-        containers.List(primitives.String).to_flyte_literal_type(),
-        ['abc', 'cde'],
-        upstream_nodes=upstream_nodes
+        containers.List(primitives.String).to_flyte_literal_type(), ["abc", "cde"], upstream_nodes=upstream_nodes,
     )

     assert len(upstream_nodes) == 0
     assert bd.promise is None
-    assert bd.collection.bindings[0].scalar.primitive.string_value == 'abc'
-    assert bd.collection.bindings[1].scalar.primitive.string_value == 'cde'
+    assert bd.collection.bindings[0].scalar.primitive.string_value == "abc"
+    assert bd.collection.bindings[1].scalar.primitive.string_value == "cde"
     assert bd.map is None
     assert bd.scalar is None
@@ -53,14 +49,12 @@ def test_binding_data_list_static():

     with pytest.raises(_user_exceptions.FlyteTypeException):
         interface.BindingData.from_python_std(
-            containers.List(primitives.String).to_flyte_literal_type(),
-            'abc',
+            containers.List(primitives.String).to_flyte_literal_type(), "abc",
         )

     with pytest.raises(_user_exceptions.FlyteTypeException):
         interface.BindingData.from_python_std(
-            containers.List(primitives.String).to_flyte_literal_type(),
-            [1.0, 2.0, 3.0]
+            containers.List(primitives.String).to_flyte_literal_type(), [1.0, 2.0, 3.0]
         )
@@ -68,16 +62,16 @@ def test_binding_generic_map_static():
     upstream_nodes = set()
     bd = interface.BindingData.from_python_std(
         primitives.Generic.to_flyte_literal_type(),
-        {'a': 'hi', 'b': [1, 2, 3], 'c': {'d': 'e'}},
-        upstream_nodes=upstream_nodes
+        {"a": "hi", "b": [1, 2, 3], "c": {"d": "e"}},
+        upstream_nodes=upstream_nodes,
     )

     assert len(upstream_nodes) == 0
     assert bd.promise is None
     assert bd.map is None
-    assert bd.scalar.generic['a'] == 'hi'
-    assert bd.scalar.generic['b'].values[0].number_value == 1.0
-    assert bd.scalar.generic['b'].values[1].number_value == 2.0
-    assert bd.scalar.generic['b'].values[2].number_value == 3.0
-    assert bd.scalar.generic['c']['d'] == 'e'
+    assert bd.scalar.generic["a"] == "hi"
+    assert bd.scalar.generic["b"].values[0].number_value == 1.0
+    assert bd.scalar.generic["b"].values[1].number_value == 2.0
+    assert bd.scalar.generic["b"].values[2].number_value == 3.0
+    assert bd.scalar.generic["c"]["d"] == "e"

     assert interface.BindingData.from_flyte_idl(bd.to_flyte_idl()) == bd
diff --git a/tests/flytekit/unit/common_tests/test_launch_plan.py b/tests/flytekit/unit/common_tests/test_launch_plan.py
index 8b85b74e78..83376e4949 100644
--- a/tests/flytekit/unit/common_tests/test_launch_plan.py
+++ b/tests/flytekit/unit/common_tests/test_launch_plan.py
@@ -1,140 +1,144 @@
 from __future__ import absolute_import
-from flytekit import configuration as _configuration
-from flytekit.common import notifications as _notifications, schedules as _schedules, launch_plan as _launch_plan
-from flytekit.common.exceptions import user as _user_exceptions
-from flytekit.models import types as _type_models, common as _common_models
-from flytekit.models.core import execution as _execution, identifier as _identifier
-from flytekit.sdk import types as _types, workflow as _workflow
+
 import os as _os
+
 import pytest as _pytest

+from flytekit import configuration as _configuration
+from flytekit.common import launch_plan as _launch_plan
+from flytekit.common import notifications as _notifications
+from flytekit.common import schedules as _schedules
+from flytekit.common.exceptions import user as _user_exceptions
+from flytekit.models import common as _common_models
+from flytekit.models import types as _type_models
+from flytekit.models.core import execution as _execution
+from flytekit.models.core import identifier as _identifier
+from flytekit.sdk import types as _types
+from flytekit.sdk import workflow as _workflow


 def test_default_assumable_iam_role():
     with _configuration.TemporaryConfiguration(
-        _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), '../../common/configs/local.config')
+        _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../common/configs/local.config",)
     ):
         workflow_to_test = _workflow.workflow(
             {},
             inputs={
-                'required_input': _workflow.Input(_types.Types.Integer),
-                'default_input': _workflow.Input(_types.Types.Integer, default=5)
-            }
+                "required_input": _workflow.Input(_types.Types.Integer),
+                "default_input": _workflow.Input(_types.Types.Integer, default=5),
+            },
         )
         lp = workflow_to_test.create_launch_plan()
-        assert lp.auth_role.assumable_iam_role == 'arn:aws:iam::ABC123:role/my-flyte-role'
+        assert lp.auth_role.assumable_iam_role == "arn:aws:iam::ABC123:role/my-flyte-role"


 def test_hard_coded_assumable_iam_role():
     workflow_to_test = _workflow.workflow(
         {},
         inputs={
-            'required_input': _workflow.Input(_types.Types.Integer),
-            'default_input': _workflow.Input(_types.Types.Integer, default=5)
-        }
+            "required_input": _workflow.Input(_types.Types.Integer),
+            "default_input": _workflow.Input(_types.Types.Integer, default=5),
+        },
     )
-    lp = workflow_to_test.create_launch_plan(assumable_iam_role='override')
-    assert lp.auth_role.assumable_iam_role == 'override'
+    lp = workflow_to_test.create_launch_plan(assumable_iam_role="override")
+    assert lp.auth_role.assumable_iam_role == "override"


 def test_default_deprecated_role():
     with _configuration.TemporaryConfiguration(
-        _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), '../../common/configs/deprecated_local.config')
+        _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../common/configs/deprecated_local.config",)
     ):
         workflow_to_test = _workflow.workflow(
             {},
             inputs={
-                'required_input': _workflow.Input(_types.Types.Integer),
-                'default_input': _workflow.Input(_types.Types.Integer, default=5)
-            }
+                "required_input": _workflow.Input(_types.Types.Integer),
+                "default_input": _workflow.Input(_types.Types.Integer, default=5),
+            },
         )
         lp = workflow_to_test.create_launch_plan()
-        assert lp.auth_role.assumable_iam_role == 'arn:aws:iam::ABC123:role/my-flyte-role'
+        assert lp.auth_role.assumable_iam_role == "arn:aws:iam::ABC123:role/my-flyte-role"


 def test_hard_coded_deprecated_role():
     workflow_to_test = _workflow.workflow(
         {},
         inputs={
-            'required_input': _workflow.Input(_types.Types.Integer),
-            'default_input': _workflow.Input(_types.Types.Integer, default=5)
-        }
+            "required_input": _workflow.Input(_types.Types.Integer),
+            "default_input": _workflow.Input(_types.Types.Integer, default=5),
+        },
     )
-    lp = workflow_to_test.create_launch_plan(role='override')
-    assert lp.auth_role.assumable_iam_role == 'override'
+    lp = workflow_to_test.create_launch_plan(role="override")
+    assert lp.auth_role.assumable_iam_role == "override"


 def test_kubernetes_service_account():
     workflow_to_test = _workflow.workflow(
         {},
         inputs={
-            'required_input': _workflow.Input(_types.Types.Integer),
-            'default_input': _workflow.Input(_types.Types.Integer, default=5)
-        }
+            "required_input": _workflow.Input(_types.Types.Integer),
+            "default_input": _workflow.Input(_types.Types.Integer, default=5),
+        },
     )
-    lp = workflow_to_test.create_launch_plan(kubernetes_service_account='kube-service-acct')
-    assert lp.auth_role.kubernetes_service_account == 'kube-service-acct'
+    lp = workflow_to_test.create_launch_plan(kubernetes_service_account="kube-service-acct")
+    assert lp.auth_role.kubernetes_service_account == "kube-service-acct"


 def test_fixed_inputs():
     workflow_to_test = _workflow.workflow(
         {},
         inputs={
-            'required_input': _workflow.Input(_types.Types.Integer),
-            'default_input': _workflow.Input(_types.Types.Integer, default=5)
-        }
-    )
-    lp = workflow_to_test.create_launch_plan(
-        fixed_inputs={'required_input': 4}
+            "required_input": _workflow.Input(_types.Types.Integer),
+            "default_input": _workflow.Input(_types.Types.Integer, default=5),
+        },
     )
+    lp = workflow_to_test.create_launch_plan(fixed_inputs={"required_input": 4})
     assert len(lp.fixed_inputs.literals) == 1
-    assert lp.fixed_inputs.literals['required_input'].scalar.primitive.integer == 4
+    assert lp.fixed_inputs.literals["required_input"].scalar.primitive.integer == 4
     assert len(lp.default_inputs.parameters) == 1
-    assert lp.default_inputs.parameters['default_input'].default.scalar.primitive.integer == 5
+    assert lp.default_inputs.parameters["default_input"].default.scalar.primitive.integer == 5


 def test_redefining_inputs_good():
     workflow_to_test = _workflow.workflow(
         {},
         inputs={
-            'required_input': _workflow.Input(_types.Types.Integer),
-            'default_input': _workflow.Input(_types.Types.Integer, default=5)
-        }
+            "required_input": _workflow.Input(_types.Types.Integer),
+            "default_input": _workflow.Input(_types.Types.Integer, default=5),
+        },
     )
     lp = workflow_to_test.create_launch_plan(
-        default_inputs={'required_input': _workflow.Input(_types.Types.Integer, default=900)}
+        default_inputs={"required_input": _workflow.Input(_types.Types.Integer, default=900)}
     )
     assert len(lp.fixed_inputs.literals) == 0
     assert len(lp.default_inputs.parameters) == 2
-    assert lp.default_inputs.parameters['required_input'].default.scalar.primitive.integer == 900
-    assert lp.default_inputs.parameters['default_input'].default.scalar.primitive.integer == 5
+    assert lp.default_inputs.parameters["required_input"].default.scalar.primitive.integer == 900
+    assert lp.default_inputs.parameters["default_input"].default.scalar.primitive.integer == 5


 def test_no_additional_inputs():
     workflow_to_test = _workflow.workflow(
         {},
         inputs={
-            'required_input': _workflow.Input(_types.Types.Integer),
-            'default_input': _workflow.Input(_types.Types.Integer, default=5)
-        }
+            "required_input": _workflow.Input(_types.Types.Integer),
+            "default_input": _workflow.Input(_types.Types.Integer, default=5),
+        },
     )
     lp = workflow_to_test.create_launch_plan()
     assert len(lp.fixed_inputs.literals) == 0
-    assert lp.default_inputs.parameters['default_input'].default.scalar.primitive.integer == 5
-    assert lp.default_inputs.parameters['required_input'].required is True
+    assert lp.default_inputs.parameters["default_input"].default.scalar.primitive.integer == 5
+    assert lp.default_inputs.parameters["required_input"].required is True


 def test_schedule():
     workflow_to_test = _workflow.workflow(
         {},
         inputs={
-            'required_input': _workflow.Input(_types.Types.Integer),
-            'default_input': _workflow.Input(_types.Types.Integer, default=5)
-        }
+            "required_input": _workflow.Input(_types.Types.Integer),
+            "default_input": _workflow.Input(_types.Types.Integer, default=5),
+        },
     )
     lp = workflow_to_test.create_launch_plan(
-        fixed_inputs={'required_input': 5},
-        schedule=_schedules.CronSchedule("* * ? * * *"),
-        role='what'
+        fixed_inputs={"required_input": 5}, schedule=_schedules.CronSchedule("* * ? * * *"), role="what",
     )
     assert lp.entity_metadata.schedule.kickoff_time_input_arg is None
     assert lp.entity_metadata.schedule.cron_expression == "* * ? * * *"
@@ -145,12 +149,12 @@ def test_no_schedule():
     workflow_to_test = _workflow.workflow(
         {},
         inputs={
-            'required_input': _workflow.Input(_types.Types.Integer),
-            'default_input': _workflow.Input(_types.Types.Integer, default=5)
-        }
+            "required_input": _workflow.Input(_types.Types.Integer),
+            "default_input": _workflow.Input(_types.Types.Integer, default=5),
+        },
     )
     lp = workflow_to_test.create_launch_plan()
-    assert lp.entity_metadata.schedule.kickoff_time_input_arg == ''
+    assert lp.entity_metadata.schedule.kickoff_time_input_arg == ""
     assert lp.entity_metadata.schedule.schedule_expression is None
     assert not lp.is_scheduled
@@ -159,15 +163,14 @@ def test_schedule_pointing_to_datetime():
     workflow_to_test = _workflow.workflow(
         {},
         inputs={
-            'required_input': _workflow.Input(_types.Types.Datetime),
-            'default_input': _workflow.Input(_types.Types.Integer, default=5)
-        }
+            "required_input": _workflow.Input(_types.Types.Datetime),
+            "default_input": _workflow.Input(_types.Types.Integer, default=5),
+        },
     )
     lp = workflow_to_test.create_launch_plan(
-        schedule=_schedules.CronSchedule("* * ? * * *", kickoff_time_input_arg='required_input'),
-        role='what'
+        schedule=_schedules.CronSchedule("* * ? * * *", kickoff_time_input_arg="required_input"), role="what",
     )
-    assert lp.entity_metadata.schedule.kickoff_time_input_arg == 'required_input'
+    assert lp.entity_metadata.schedule.kickoff_time_input_arg == "required_input"
* * *" @@ -175,17 +178,15 @@ def test_notifications(): workflow_to_test = _workflow.workflow( {}, inputs={ - 'required_input': _workflow.Input(_types.Types.Integer), - 'default_input': _workflow.Input(_types.Types.Integer, default=5) - } + "required_input": _workflow.Input(_types.Types.Integer), + "default_input": _workflow.Input(_types.Types.Integer, default=5), + }, ) lp = workflow_to_test.create_launch_plan( - notifications=[ - _notifications.PagerDuty([_execution.WorkflowExecutionPhase.FAILED], ['me@myplace.com']) - ] + notifications=[_notifications.PagerDuty([_execution.WorkflowExecutionPhase.FAILED], ["me@myplace.com"])] ) assert len(lp.entity_metadata.notifications) == 1 - assert lp.entity_metadata.notifications[0].pager_duty.recipients_email == ['me@myplace.com'] + assert lp.entity_metadata.notifications[0].pager_duty.recipients_email == ["me@myplace.com"] assert lp.entity_metadata.notifications[0].phases == [_execution.WorkflowExecutionPhase.FAILED] @@ -193,9 +194,9 @@ def test_no_notifications(): workflow_to_test = _workflow.workflow( {}, inputs={ - 'required_input': _workflow.Input(_types.Types.Integer), - 'default_input': _workflow.Input(_types.Types.Integer, default=5) - } + "required_input": _workflow.Input(_types.Types.Integer), + "default_input": _workflow.Input(_types.Types.Integer, default=5), + }, ) lp = workflow_to_test.create_launch_plan() assert len(lp.entity_metadata.notifications) == 0 @@ -205,12 +206,10 @@ def test_launch_plan_node(): workflow_to_test = _workflow.workflow( {}, inputs={ - 'required_input': _workflow.Input(_types.Types.Integer), - 'default_input': _workflow.Input(_types.Types.Integer, default=5) + "required_input": _workflow.Input(_types.Types.Integer), + "default_input": _workflow.Input(_types.Types.Integer, default=5), }, - outputs={ - 'out': _workflow.Output([1, 2, 3], sdk_type=[_types.Types.Integer]) - } + outputs={"out": _workflow.Output([1, 2, 3], sdk_type=[_types.Types.Integer])}, ) lp = workflow_to_test.create_launch_plan() @@ -224,7 +223,7 @@ def test_launch_plan_node(): # Test that type checking works with _pytest.raises(_user_exceptions.FlyteTypeException): - lp(required_input='abc', default_input=1) + lp(required_input="abc", default_input=1) # Test that bad arg name is detected with _pytest.raises(_user_exceptions.FlyteAssertion): @@ -232,43 +231,43 @@ def test_launch_plan_node(): # Test default input is accounted for n = lp(required_input=10) - assert n.inputs[0].var == 'default_input' + assert n.inputs[0].var == "default_input" assert n.inputs[0].binding.scalar.primitive.integer == 5 - assert n.inputs[1].var == 'required_input' + assert n.inputs[1].var == "required_input" assert n.inputs[1].binding.scalar.primitive.integer == 10 # Test default input is overridden n = lp(required_input=10, default_input=50) - assert n.inputs[0].var == 'default_input' + assert n.inputs[0].var == "default_input" assert n.inputs[0].binding.scalar.primitive.integer == 50 - assert n.inputs[1].var == 'required_input' + assert n.inputs[1].var == "required_input" assert n.inputs[1].binding.scalar.primitive.integer == 10 # Test that launch plan ID ref is flexible - lp._id = 'fake' - assert n.workflow_node.launchplan_ref == 'fake' + lp._id = "fake" + assert n.workflow_node.launchplan_ref == "fake" lp._id = None # Test that outputs are promised - n.assign_id_and_return('node-id') - assert n.outputs['out'].sdk_type.to_flyte_literal_type().collection_type.simple == _type_models.SimpleType.INTEGER - assert n.outputs['out'].var == 'out' - assert 
n.outputs['out'].node_id == 'node-id' + n.assign_id_and_return("node-id") + assert n.outputs["out"].sdk_type.to_flyte_literal_type().collection_type.simple == _type_models.SimpleType.INTEGER + assert n.outputs["out"].var == "out" + assert n.outputs["out"].node_id == "node-id" def test_labels(): workflow_to_test = _workflow.workflow( {}, inputs={ - 'required_input': _workflow.Input(_types.Types.Integer), - 'default_input': _workflow.Input(_types.Types.Integer, default=5) - } + "required_input": _workflow.Input(_types.Types.Integer), + "default_input": _workflow.Input(_types.Types.Integer, default=5), + }, ) lp = workflow_to_test.create_launch_plan( - fixed_inputs={'required_input': 5}, + fixed_inputs={"required_input": 5}, schedule=_schedules.CronSchedule("* * ? * * *"), - role='what', - labels=_common_models.Labels({"my": "label"}) + role="what", + labels=_common_models.Labels({"my": "label"}), ) assert lp.labels.values == {"my": "label"} @@ -277,15 +276,15 @@ def test_annotations(): workflow_to_test = _workflow.workflow( {}, inputs={ - 'required_input': _workflow.Input(_types.Types.Integer), - 'default_input': _workflow.Input(_types.Types.Integer, default=5) - } + "required_input": _workflow.Input(_types.Types.Integer), + "default_input": _workflow.Input(_types.Types.Integer, default=5), + }, ) lp = workflow_to_test.create_launch_plan( - fixed_inputs={'required_input': 5}, + fixed_inputs={"required_input": 5}, schedule=_schedules.CronSchedule("* * ? * * *"), - role='what', - annotations=_common_models.Annotations({"my": "annotation"}) + role="what", + annotations=_common_models.Annotations({"my": "annotation"}), ) assert lp.annotations.values == {"my": "annotation"} @@ -294,44 +293,37 @@ def test_serialize(): workflow_to_test = _workflow.workflow( {}, inputs={ - 'required_input': _workflow.Input(_types.Types.Integer), - 'default_input': _workflow.Input(_types.Types.Integer, default=5) - } + "required_input": _workflow.Input(_types.Types.Integer), + "default_input": _workflow.Input(_types.Types.Integer, default=5), + }, ) workflow_to_test._id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p", "d", "n", "v") - lp = workflow_to_test.create_launch_plan( - fixed_inputs={'required_input': 5}, - role='iam_role', - ) + lp = workflow_to_test.create_launch_plan(fixed_inputs={"required_input": 5}, role="iam_role",) with _configuration.TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), '../../common/configs/local.config'), - internal_overrides={ - 'image': 'myflyteimage:v123', - 'project': 'myflyteproject', - 'domain': 'development' - } + _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../common/configs/local.config",), + internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, ): s = lp.serialize() assert s.workflow_id == _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p", "d", "n", "v").to_flyte_idl() - assert s.auth_role.assumable_iam_role == 'iam_role' - assert s.default_inputs.parameters['default_input'].default.scalar.primitive.integer == 5 + assert s.auth_role.assumable_iam_role == "iam_role" + assert s.default_inputs.parameters["default_input"].default.scalar.primitive.integer == 5 def test_promote_from_model(): workflow_to_test = _workflow.workflow( {}, inputs={ - 'required_input': _workflow.Input(_types.Types.Integer), - 'default_input': _workflow.Input(_types.Types.Integer, default=5) - } + "required_input": _workflow.Input(_types.Types.Integer), + "default_input": 
_workflow.Input(_types.Types.Integer, default=5), + }, ) workflow_to_test._id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p", "d", "n", "v") lp = workflow_to_test.create_launch_plan( - fixed_inputs={'required_input': 5}, + fixed_inputs={"required_input": 5}, schedule=_schedules.CronSchedule("* * ? * * *"), - role='what', - labels=_common_models.Labels({"my": "label"}) + role="what", + labels=_common_models.Labels({"my": "label"}), ) with _pytest.raises(_user_exceptions.FlyteAssertion): @@ -347,17 +339,14 @@ def test_raw_data_output_prefix(): workflow_to_test = _workflow.workflow( {}, inputs={ - 'required_input': _workflow.Input(_types.Types.Integer), - 'default_input': _workflow.Input(_types.Types.Integer, default=5) - } + "required_input": _workflow.Input(_types.Types.Integer), + "default_input": _workflow.Input(_types.Types.Integer, default=5), + }, ) lp = workflow_to_test.create_launch_plan( - fixed_inputs={'required_input': 5}, - raw_output_data_prefix='s3://bucket-name', + fixed_inputs={"required_input": 5}, raw_output_data_prefix="s3://bucket-name", ) - assert lp.raw_output_data_config.output_location_prefix == 's3://bucket-name' + assert lp.raw_output_data_config.output_location_prefix == "s3://bucket-name" - lp2 = workflow_to_test.create_launch_plan( - fixed_inputs={'required_input': 5}, - ) - assert lp2.raw_output_data_config.output_location_prefix == '' + lp2 = workflow_to_test.create_launch_plan(fixed_inputs={"required_input": 5},) + assert lp2.raw_output_data_config.output_location_prefix == "" diff --git a/tests/flytekit/unit/common_tests/test_nodes.py b/tests/flytekit/unit/common_tests/test_nodes.py index 8e8dfa97f2..fe051de186 100644 --- a/tests/flytekit/unit/common_tests/test_nodes.py +++ b/tests/flytekit/unit/common_tests/test_nodes.py @@ -1,13 +1,20 @@ from __future__ import absolute_import -from flytekit.common import nodes as _nodes, interface as _interface, component_nodes as _component_nodes -from flytekit.models.core import workflow as _core_workflow_models, identifier as _identifier -from flytekit.models import literals as _literals -from flytekit.sdk import tasks as _tasks, types as _types, workflow as _workflow -from flytekit.common.exceptions import system as _system_exceptions import datetime as _datetime + import pytest as _pytest +from flytekit.common import component_nodes as _component_nodes +from flytekit.common import interface as _interface +from flytekit.common import nodes as _nodes +from flytekit.common.exceptions import system as _system_exceptions +from flytekit.models import literals as _literals +from flytekit.models.core import identifier as _identifier +from flytekit.models.core import workflow as _core_workflow_models +from flytekit.sdk import tasks as _tasks +from flytekit.sdk import types as _types +from flytekit.sdk import workflow as _workflow + def test_sdk_node_from_task(): @_tasks.inputs(a=_types.Types.Integer) @@ -17,68 +24,66 @@ def testy_test(wf_params, a, b): pass n = _nodes.SdkNode( - 'n', + "n", [], [ _literals.Binding( - 'a', - _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3) + "a", _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), ) ], _core_workflow_models.NodeMetadata("abc", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), sdk_task=testy_test, sdk_workflow=None, sdk_launch_plan=None, - sdk_branch=None + sdk_branch=None, ) - assert n.id == 'n' + assert n.id == "n" assert len(n.inputs) == 1 - assert n.inputs[0].var == 'a' + 
assert n.inputs[0].var == "a" assert n.inputs[0].binding.scalar.primitive.integer == 3 assert len(n.outputs) == 1 - assert 'b' in n.outputs - assert n.outputs['b'].node_id == 'n' - assert n.outputs['b'].var == 'b' - assert n.outputs['b'].sdk_node == n - assert n.outputs['b'].sdk_type == _types.Types.Integer - assert n.metadata.name == 'abc' + assert "b" in n.outputs + assert n.outputs["b"].node_id == "n" + assert n.outputs["b"].var == "b" + assert n.outputs["b"].sdk_node == n + assert n.outputs["b"].sdk_type == _types.Types.Integer + assert n.metadata.name == "abc" assert n.metadata.retries.retries == 3 - assert n.metadata.interruptible == False + assert n.metadata.interruptible is False assert len(n.upstream_nodes) == 0 assert len(n.upstream_node_ids) == 0 assert len(n.output_aliases) == 0 n2 = _nodes.SdkNode( - 'n2', + "n2", [n], [ _literals.Binding( - 'a', - _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), n.outputs.b) + "a", _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), n.outputs.b), ) ], _core_workflow_models.NodeMetadata("abc2", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), sdk_task=testy_test, sdk_workflow=None, sdk_launch_plan=None, - sdk_branch=None + sdk_branch=None, ) - assert n2.id == 'n2' + assert n2.id == "n2" assert len(n2.inputs) == 1 - assert n2.inputs[0].var == 'a' - assert n2.inputs[0].binding.promise.var == 'b' - assert n2.inputs[0].binding.promise.node_id == 'n' + assert n2.inputs[0].var == "a" + assert n2.inputs[0].binding.promise.var == "b" + assert n2.inputs[0].binding.promise.node_id == "n" assert len(n2.outputs) == 1 - assert 'b' in n2.outputs - assert n2.outputs['b'].node_id == 'n2' - assert n2.outputs['b'].var == 'b' - assert n2.outputs['b'].sdk_node == n2 - assert n2.outputs['b'].sdk_type == _types.Types.Integer - assert n2.metadata.name == 'abc2' + assert "b" in n2.outputs + assert n2.outputs["b"].node_id == "n2" + assert n2.outputs["b"].var == "b" + assert n2.outputs["b"].sdk_node == n2 + assert n2.outputs["b"].sdk_type == _types.Types.Integer + assert n2.metadata.name == "abc2" assert n2.metadata.retries.retries == 3 - assert 'n' in n2.upstream_node_ids + assert "n" in n2.upstream_node_ids assert n in n2.upstream_nodes assert len(n2.upstream_nodes) == 1 assert len(n2.upstream_node_ids) == 1 @@ -86,38 +91,37 @@ def testy_test(wf_params, a, b): # Test right shift operator and late binding n3 = _nodes.SdkNode( - 'n3', + "n3", [], [ _literals.Binding( - 'a', - _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3) + "a", _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), ) ], _core_workflow_models.NodeMetadata("abc3", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), sdk_task=testy_test, sdk_workflow=None, sdk_launch_plan=None, - sdk_branch=None + sdk_branch=None, ) n2 >> n3 n >> n2 >> n3 n3 << n2 n3 << n2 << n - assert n3.id == 'n3' + assert n3.id == "n3" assert len(n3.inputs) == 1 - assert n3.inputs[0].var == 'a' + assert n3.inputs[0].var == "a" assert n3.inputs[0].binding.scalar.primitive.integer == 3 assert len(n3.outputs) == 1 - assert 'b' in n3.outputs - assert n3.outputs['b'].node_id == 'n3' - assert n3.outputs['b'].var == 'b' - assert n3.outputs['b'].sdk_node == n3 - assert n3.outputs['b'].sdk_type == _types.Types.Integer - assert n3.metadata.name == 'abc3' + assert "b" in n3.outputs + assert n3.outputs["b"].node_id == "n3" + assert n3.outputs["b"].var == "b" + assert 
n3.outputs["b"].sdk_node == n3 + assert n3.outputs["b"].sdk_type == _types.Types.Integer + assert n3.metadata.name == "abc3" assert n3.metadata.retries.retries == 3 - assert 'n2' in n3.upstream_node_ids + assert "n2" in n3.upstream_node_ids assert n2 in n3.upstream_nodes assert len(n3.upstream_nodes) == 1 assert len(n3.upstream_node_ids) == 1 @@ -125,19 +129,18 @@ def testy_test(wf_params, a, b): # Test left shift operator and late binding n4 = _nodes.SdkNode( - 'n4', + "n4", [], [ _literals.Binding( - 'a', - _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3) + "a", _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), ) ], _core_workflow_models.NodeMetadata("abc4", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), sdk_task=testy_test, sdk_workflow=None, sdk_launch_plan=None, - sdk_branch=None + sdk_branch=None, ) n4 << n3 @@ -146,19 +149,19 @@ def testy_test(wf_params, a, b): n4 << n3 << n2 << n n >> n2 >> n3 >> n4 - assert n4.id == 'n4' + assert n4.id == "n4" assert len(n4.inputs) == 1 - assert n4.inputs[0].var == 'a' + assert n4.inputs[0].var == "a" assert n4.inputs[0].binding.scalar.primitive.integer == 3 assert len(n4.outputs) == 1 - assert 'b' in n4.outputs - assert n4.outputs['b'].node_id == 'n4' - assert n4.outputs['b'].var == 'b' - assert n4.outputs['b'].sdk_node == n4 - assert n4.outputs['b'].sdk_type == _types.Types.Integer - assert n4.metadata.name == 'abc4' + assert "b" in n4.outputs + assert n4.outputs["b"].node_id == "n4" + assert n4.outputs["b"].var == "b" + assert n4.outputs["b"].sdk_node == n4 + assert n4.outputs["b"].sdk_type == _types.Types.Integer + assert n4.metadata.name == "abc4" assert n4.metadata.retries.retries == 3 - assert 'n3' in n4.upstream_node_ids + assert "n3" in n4.upstream_node_ids assert n3 in n4.upstream_nodes assert len(n4.upstream_nodes) == 1 assert len(n4.upstream_node_ids) == 1 @@ -166,9 +169,9 @@ def testy_test(wf_params, a, b): # Add another dependency n4 << n2 - assert 'n3' in n4.upstream_node_ids + assert "n3" in n4.upstream_node_ids assert n3 in n4.upstream_nodes - assert 'n2' in n4.upstream_node_ids + assert "n2" in n4.upstream_node_ids assert n2 in n4.upstream_nodes assert len(n4.upstream_nodes) == 2 assert len(n4.upstream_node_ids) == 2 @@ -181,25 +184,21 @@ def test_sdk_task_node(): def testy_test(wf_params, a, b): pass - testy_test._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'name', 'version') + testy_test._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version") n = _component_nodes.SdkTaskNode(testy_test) - assert n.reference_id.project == 'project' - assert n.reference_id.domain == 'domain' - assert n.reference_id.name == 'name' - assert n.reference_id.version == 'version' + assert n.reference_id.project == "project" + assert n.reference_id.domain == "domain" + assert n.reference_id.name == "name" + assert n.reference_id.version == "version" # Test floating ID testy_test._id = _identifier.Identifier( - _identifier.ResourceType.TASK, - 'new_project', - 'new_domain', - 'new_name', - 'new_version' + _identifier.ResourceType.TASK, "new_project", "new_domain", "new_name", "new_version", ) - assert n.reference_id.project == 'new_project' - assert n.reference_id.domain == 'new_domain' - assert n.reference_id.name == 'new_name' - assert n.reference_id.version == 'new_version' + assert n.reference_id.project == "new_project" + assert n.reference_id.domain == "new_domain" + assert 
+    assert n.reference_id.name == "new_name"
+    assert n.reference_id.version == "new_version"


 def test_sdk_node_from_lp():
@@ -218,29 +217,28 @@ class test_workflow(object):
     lp = test_workflow.create_launch_plan()

     n1 = _nodes.SdkNode(
-        'n1',
+        "n1",
         [],
         [
             _literals.Binding(
-                'a',
-                _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3)
+                "a", _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3),
             )
         ],
         _core_workflow_models.NodeMetadata("abc", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)),
         sdk_launch_plan=lp,
     )

-    assert n1.id == 'n1'
+    assert n1.id == "n1"
     assert len(n1.inputs) == 1
-    assert n1.inputs[0].var == 'a'
+    assert n1.inputs[0].var == "a"
     assert n1.inputs[0].binding.scalar.primitive.integer == 3
     assert len(n1.outputs) == 1
-    assert 'b' in n1.outputs
-    assert n1.outputs['b'].node_id == 'n1'
-    assert n1.outputs['b'].var == 'b'
-    assert n1.outputs['b'].sdk_node == n1
-    assert n1.outputs['b'].sdk_type == _types.Types.Integer
-    assert n1.metadata.name == 'abc'
+    assert "b" in n1.outputs
+    assert n1.outputs["b"].node_id == "n1"
+    assert n1.outputs["b"].var == "b"
+    assert n1.outputs["b"].sdk_node == n1
+    assert n1.outputs["b"].sdk_type == _types.Types.Integer
+    assert n1.metadata.name == "abc"
     assert n1.metadata.retries.retries == 3
     assert len(n1.upstream_nodes) == 0
     assert len(n1.upstream_node_ids) == 0
@@ -262,25 +260,21 @@ class test_workflow(object):

     lp = test_workflow.create_launch_plan()

-    lp._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'name', 'version')
+    lp._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version")
     n = _component_nodes.SdkWorkflowNode(sdk_launch_plan=lp)
-    assert n.launchplan_ref.project == 'project'
-    assert n.launchplan_ref.domain == 'domain'
-    assert n.launchplan_ref.name == 'name'
-    assert n.launchplan_ref.version == 'version'
+    assert n.launchplan_ref.project == "project"
+    assert n.launchplan_ref.domain == "domain"
+    assert n.launchplan_ref.name == "name"
+    assert n.launchplan_ref.version == "version"

     # Test floating ID
     lp._id = _identifier.Identifier(
-        _identifier.ResourceType.TASK,
-        'new_project',
-        'new_domain',
-        'new_name',
-        'new_version'
+        _identifier.ResourceType.TASK, "new_project", "new_domain", "new_name", "new_version",
     )
-    assert n.launchplan_ref.project == 'new_project'
-    assert n.launchplan_ref.domain == 'new_domain'
-    assert n.launchplan_ref.name == 'new_name'
-    assert n.launchplan_ref.version == 'new_version'
+    assert n.launchplan_ref.project == "new_project"
+    assert n.launchplan_ref.domain == "new_domain"
+    assert n.launchplan_ref.name == "new_name"
+    assert n.launchplan_ref.version == "new_version"

     # If you specify both, you should get an exception
     with _pytest.raises(_system_exceptions.FlyteSystemException):
diff --git a/tests/flytekit/unit/common_tests/test_notifications.py b/tests/flytekit/unit/common_tests/test_notifications.py
index c57a37bef9..7124e82726 100644
--- a/tests/flytekit/unit/common_tests/test_notifications.py
+++ b/tests/flytekit/unit/common_tests/test_notifications.py
@@ -1,4 +1,5 @@
 from __future__ import absolute_import
+
 from flytekit.common import notifications as _notifications
 from flytekit.models.core import execution as _execution_model
diff --git a/tests/flytekit/unit/common_tests/test_promise.py b/tests/flytekit/unit/common_tests/test_promise.py
index a753787bd4..223373b80b 100644
--- a/tests/flytekit/unit/common_tests/test_promise.py
+++ b/tests/flytekit/unit/common_tests/test_promise.py
@@ -1,13 +1,15 @@
 from __future__ import absolute_import
+
+import pytest
+
 from flytekit.common import promise
-from flytekit.common.types import primitives, base_sdk_types
 from flytekit.common.exceptions import user as _user_exceptions
-import pytest
+from flytekit.common.types import base_sdk_types, primitives


 def test_input():
-    i = promise.Input('name', primitives.Integer, help="blah", default=None)
-    assert i.name == 'name'
+    i = promise.Input("name", primitives.Integer, help="blah", default=None)
+    assert i.name == "name"
     assert i.sdk_default is None
     assert i.default == base_sdk_types.Void()
     assert i.sdk_required is False
@@ -15,9 +17,9 @@ def test_input():
     assert i.var.description == "blah"
     assert i.sdk_type == primitives.Integer

-    i = promise.Input('name2', primitives.Integer, default=1)
-    assert i.name == 'name2'
-    assert i.sdk_default is 1
+    i = promise.Input("name2", primitives.Integer, default=1)
+    assert i.name == "name2"
+    assert i.sdk_default == 1
     assert i.default == primitives.Integer(1)
     assert i.required is None
     assert i.sdk_required is False
@@ -26,4 +28,4 @@ def test_input():
     assert i.sdk_type == primitives.Integer

     with pytest.raises(_user_exceptions.FlyteAssertion):
-        promise.Input('abc', primitives.Integer, required=True, default=1)
+        promise.Input("abc", primitives.Integer, required=True, default=1)
diff --git a/tests/flytekit/unit/common_tests/test_schedules.py b/tests/flytekit/unit/common_tests/test_schedules.py
index 6bdf494709..d860019576 100644
--- a/tests/flytekit/unit/common_tests/test_schedules.py
+++ b/tests/flytekit/unit/common_tests/test_schedules.py
@@ -1,9 +1,12 @@
 from __future__ import absolute_import
-from flytekit.common import schedules as _schedules
-from flytekit.common.exceptions import user as _user_exceptions
+
 import datetime as _datetime
+
 import pytest as _pytest

+from flytekit.common import schedules as _schedules
+from flytekit.common.exceptions import user as _user_exceptions
+

 def test_cron():
     obj = _schedules.CronSchedule("* * ? * * *", kickoff_time_input_arg="abc")
diff --git a/tests/flytekit/unit/common_tests/test_workflow.py b/tests/flytekit/unit/common_tests/test_workflow.py
index 3a307a8ac2..53e1b14ade 100644
--- a/tests/flytekit/unit/common_tests/test_workflow.py
+++ b/tests/flytekit/unit/common_tests/test_workflow.py
@@ -1,20 +1,21 @@
 from __future__ import absolute_import

 import pytest as _pytest
+from flyteidl.admin import workflow_pb2 as _workflow_pb2

-from flytekit.common import workflow, constants, promise, nodes, interface
+from flytekit.common import constants, interface, nodes, promise, workflow
 from flytekit.common.exceptions import user as _user_exceptions
-from flytekit.common.types import primitives, containers
+from flytekit.common.types import containers, primitives
 from flytekit.models import literals as _literals
-from flytekit.models.core import workflow as _workflow_models, identifier as _identifier
+from flytekit.models.core import identifier as _identifier
+from flytekit.models.core import workflow as _workflow_models
 from flytekit.sdk import types as _types
-from flytekit.sdk.tasks import python_task, inputs, outputs
-from flyteidl.admin import workflow_pb2 as _workflow_pb2
+from flytekit.sdk.tasks import inputs, outputs, python_task


 def test_output():
-    o = workflow.Output('name', 1, sdk_type=primitives.Integer, help="blah")
-    assert o.name == 'name'
+    o = workflow.Output("name", 1, sdk_type=primitives.Integer, help="blah")
+    assert o.name == "name"
     assert o.var.description == "blah"
     assert o.var.type == primitives.Integer.to_flyte_literal_type()
     assert o.binding_data.scalar.primitive.integer == 1
@@ -27,7 +28,7 @@ def test_workflow():
     def my_task(wf_params, a, b):
         b.set(a + 1)

-    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'my_task', 'version')
+    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "my_task", "version")

     @inputs(a=[primitives.Integer])
     @outputs(b=[primitives.Integer])
@@ -35,70 +36,69 @@ def my_task(wf_params, a, b):
     def my_list_task(wf_params, a, b):
         b.set([v + 1 for v in a])

-    my_list_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'my_list_task',
-                                              'version')
+    my_list_task._id = _identifier.Identifier(
+        _identifier.ResourceType.TASK, "project", "domain", "my_list_task", "version"
+    )

     input_list = [
-        promise.Input('input_1', primitives.Integer),
-        promise.Input('input_2', primitives.Integer, default=5, help='Not required.')
+        promise.Input("input_1", primitives.Integer),
+        promise.Input("input_2", primitives.Integer, default=5, help="Not required."),
     ]

-    n1 = my_task(a=input_list[0]).assign_id_and_return('n1')
-    n2 = my_task(a=input_list[1]).assign_id_and_return('n2')
-    n3 = my_task(a=100).assign_id_and_return('n3')
-    n4 = my_task(a=n1.outputs.b).assign_id_and_return('n4')
-    n5 = my_list_task(a=[input_list[0], input_list[1], n3.outputs.b, 100]).assign_id_and_return('n5')
+    n1 = my_task(a=input_list[0]).assign_id_and_return("n1")
+    n2 = my_task(a=input_list[1]).assign_id_and_return("n2")
+    n3 = my_task(a=100).assign_id_and_return("n3")
+    n4 = my_task(a=n1.outputs.b).assign_id_and_return("n4")
+    n5 = my_list_task(a=[input_list[0], input_list[1], n3.outputs.b, 100]).assign_id_and_return("n5")
     n6 = my_list_task(a=n5.outputs.b)
     n1 >> n6

     nodes = [n1, n2, n3, n4, n5, n6]

     w = workflow.SdkWorkflow(
-        inputs=input_list,
-        outputs=[workflow.Output('a', n1.outputs.b, sdk_type=primitives.Integer)],
-        nodes=nodes
+        inputs=input_list, outputs=[workflow.Output("a", n1.outputs.b, sdk_type=primitives.Integer)], nodes=nodes,
     )

-    assert w.interface.inputs['input_1'].type == primitives.Integer.to_flyte_literal_type()
-    assert w.interface.inputs['input_2'].type == primitives.Integer.to_flyte_literal_type()
-    assert w.nodes[0].inputs[0].var == 'a'
+    assert w.interface.inputs["input_1"].type == primitives.Integer.to_flyte_literal_type()
+    assert w.interface.inputs["input_2"].type == primitives.Integer.to_flyte_literal_type()
+    assert w.nodes[0].inputs[0].var == "a"
     assert w.nodes[0].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert w.nodes[0].inputs[0].binding.promise.var == 'input_1'
+    assert w.nodes[0].inputs[0].binding.promise.var == "input_1"
     assert w.nodes[1].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert w.nodes[1].inputs[0].binding.promise.var == 'input_2'
+    assert w.nodes[1].inputs[0].binding.promise.var == "input_2"
     assert w.nodes[2].inputs[0].binding.scalar.primitive.integer == 100
-    assert w.nodes[3].inputs[0].var == 'a'
+    assert w.nodes[3].inputs[0].var == "a"
     assert w.nodes[3].inputs[0].binding.promise.node_id == n1.id

     # Test conversion to flyte_idl and back
-    w._id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, 'fake', 'faker', 'fakest', 'fakerest')
+    w._id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "fake", "faker", "fakest", "fakerest")
     w = _workflow_models.WorkflowTemplate.from_flyte_idl(w.to_flyte_idl())
-    assert w.interface.inputs['input_1'].type == primitives.Integer.to_flyte_literal_type()
-    assert w.interface.inputs['input_2'].type == primitives.Integer.to_flyte_literal_type()
-    assert w.nodes[0].inputs[0].var == 'a'
+    assert w.interface.inputs["input_1"].type == primitives.Integer.to_flyte_literal_type()
+    assert w.interface.inputs["input_2"].type == primitives.Integer.to_flyte_literal_type()
+    assert w.nodes[0].inputs[0].var == "a"
     assert w.nodes[0].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert w.nodes[0].inputs[0].binding.promise.var == 'input_1'
+    assert w.nodes[0].inputs[0].binding.promise.var == "input_1"
     assert w.nodes[1].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert w.nodes[1].inputs[0].binding.promise.var == 'input_2'
+    assert w.nodes[1].inputs[0].binding.promise.var == "input_2"
     assert w.nodes[2].inputs[0].binding.scalar.primitive.integer == 100
-    assert w.nodes[3].inputs[0].var == 'a'
+    assert w.nodes[3].inputs[0].var == "a"
     assert w.nodes[3].inputs[0].binding.promise.node_id == n1.id
-    assert w.nodes[4].inputs[0].var == 'a'
+    assert w.nodes[4].inputs[0].var == "a"
     assert w.nodes[4].inputs[0].binding.collection.bindings[0].promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert w.nodes[4].inputs[0].binding.collection.bindings[0].promise.var == 'input_1'
+    assert w.nodes[4].inputs[0].binding.collection.bindings[0].promise.var == "input_1"
     assert w.nodes[4].inputs[0].binding.collection.bindings[1].promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert w.nodes[4].inputs[0].binding.collection.bindings[1].promise.var == 'input_2'
+    assert w.nodes[4].inputs[0].binding.collection.bindings[1].promise.var == "input_2"
     assert w.nodes[4].inputs[0].binding.collection.bindings[2].promise.node_id == n3.id
-    assert w.nodes[4].inputs[0].binding.collection.bindings[2].promise.var == 'b'
+    assert w.nodes[4].inputs[0].binding.collection.bindings[2].promise.var == "b"
     assert w.nodes[4].inputs[0].binding.collection.bindings[3].scalar.primitive.integer == 100
-    assert w.nodes[5].inputs[0].var == 'a'
+    assert w.nodes[5].inputs[0].var == "a"
     assert w.nodes[5].inputs[0].binding.promise.node_id == n5.id
-    assert w.nodes[5].inputs[0].binding.promise.var == 'b'
+    assert w.nodes[5].inputs[0].binding.promise.var == "b"
     assert len(w.outputs) == 1
-    assert w.outputs[0].var == 'a'
-    assert w.outputs[0].binding.promise.var == 'b'
-    assert w.outputs[0].binding.promise.node_id == 'n1'
+    assert w.outputs[0].var == "a"
+    assert w.outputs[0].binding.promise.var == "b"
+    assert w.outputs[0].binding.promise.node_id == "n1"

     # TODO: Test promotion of w -> SdkWorkflow
@@ -109,7 +109,7 @@ def test_workflow_decorator():
     def my_task(wf_params, a, b):
         b.set(a + 1)

-    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'propject', 'domain', 'my_task', 'version')
+    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "propject", "domain", "my_task", "version")

     @inputs(a=[primitives.Integer])
     @outputs(b=[primitives.Integer])
@@ -117,12 +117,13 @@ def my_task(wf_params, a, b):
     def my_list_task(wf_params, a, b):
         b.set([v + 1 for v in a])

-    my_list_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'propject', 'domain', 'my_list_task',
-                                              'version')
+    my_list_task._id = _identifier.Identifier(
+        _identifier.ResourceType.TASK, "propject", "domain", "my_list_task", "version"
+    )

     class my_workflow(object):
-        input_1 = promise.Input('input_1', primitives.Integer)
-        input_2 = promise.Input('input_2', primitives.Integer, default=5, help='Not required.')
+        input_1 = promise.Input("input_1", primitives.Integer)
+        input_2 = promise.Input("input_2", primitives.Integer, default=5, help="Not required.")
         n1 = my_task(a=input_1)
         n2 = my_task(a=input_2)
         n3 = my_task(a=100)
@@ -130,51 +131,55 @@ class my_workflow(object):
         n5 = my_list_task(a=[input_1, input_2, n3.outputs.b, 100])
         n6 = my_list_task(a=n5.outputs.b)
         n1 >> n6
-        a = workflow.Output('a', n1.outputs.b, sdk_type=primitives.Integer)
+        a = workflow.Output("a", n1.outputs.b, sdk_type=primitives.Integer)

-    w = workflow.build_sdk_workflow_from_metaclass(my_workflow, on_failure=_workflow_models.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE)
+    w = workflow.build_sdk_workflow_from_metaclass(
+        my_workflow, on_failure=_workflow_models.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE,
+    )

-    assert w.interface.inputs['input_1'].type == primitives.Integer.to_flyte_literal_type()
-    assert w.interface.inputs['input_2'].type == primitives.Integer.to_flyte_literal_type()
-    assert w.nodes[0].inputs[0].var == 'a'
+    assert w.interface.inputs["input_1"].type == primitives.Integer.to_flyte_literal_type()
+    assert w.interface.inputs["input_2"].type == primitives.Integer.to_flyte_literal_type()
+    assert w.nodes[0].inputs[0].var == "a"
     assert w.nodes[0].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert w.nodes[0].inputs[0].binding.promise.var == 'input_1'
+    assert w.nodes[0].inputs[0].binding.promise.var == "input_1"
     assert w.nodes[1].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert w.nodes[1].inputs[0].binding.promise.var == 'input_2'
+    assert w.nodes[1].inputs[0].binding.promise.var == "input_2"
     assert w.nodes[2].inputs[0].binding.scalar.primitive.integer == 100
-    assert w.nodes[3].inputs[0].var == 'a'
-    assert w.nodes[3].inputs[0].binding.promise.node_id == 'n1'
+    assert w.nodes[3].inputs[0].var == "a"
+    assert w.nodes[3].inputs[0].binding.promise.node_id == "n1"

     # Test conversion to flyte_idl and back
-    w._id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, 'fake', 'faker', 'fakest', 'fakerest')
+    w._id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "fake", "faker", "fakest", "fakerest")
     w = _workflow_models.WorkflowTemplate.from_flyte_idl(w.to_flyte_idl())
-    assert w.interface.inputs['input_1'].type == primitives.Integer.to_flyte_literal_type()
-    assert w.interface.inputs['input_2'].type == primitives.Integer.to_flyte_literal_type()
-    assert w.nodes[0].inputs[0].var == 'a'
+    assert w.interface.inputs["input_1"].type == primitives.Integer.to_flyte_literal_type()
+    assert w.interface.inputs["input_2"].type == primitives.Integer.to_flyte_literal_type()
+    assert w.nodes[0].inputs[0].var == "a"
     assert w.nodes[0].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert w.nodes[0].inputs[0].binding.promise.var == 'input_1'
+    assert w.nodes[0].inputs[0].binding.promise.var == "input_1"
     assert w.nodes[1].inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert w.nodes[1].inputs[0].binding.promise.var == 'input_2'
+    assert w.nodes[1].inputs[0].binding.promise.var == "input_2"
     assert w.nodes[2].inputs[0].binding.scalar.primitive.integer == 100
-    assert w.nodes[3].inputs[0].var == 'a'
-    assert w.nodes[3].inputs[0].binding.promise.node_id == 'n1'
-    assert w.nodes[4].inputs[0].var == 'a'
+    assert w.nodes[3].inputs[0].var == "a"
+    assert w.nodes[3].inputs[0].binding.promise.node_id == "n1"
+    assert w.nodes[4].inputs[0].var == "a"
     assert w.nodes[4].inputs[0].binding.collection.bindings[0].promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert w.nodes[4].inputs[0].binding.collection.bindings[0].promise.var == 'input_1'
+    assert w.nodes[4].inputs[0].binding.collection.bindings[0].promise.var == "input_1"
     assert w.nodes[4].inputs[0].binding.collection.bindings[1].promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert w.nodes[4].inputs[0].binding.collection.bindings[1].promise.var == 'input_2'
-    assert w.nodes[4].inputs[0].binding.collection.bindings[2].promise.node_id == 'n3'
-    assert w.nodes[4].inputs[0].binding.collection.bindings[2].promise.var == 'b'
+    assert w.nodes[4].inputs[0].binding.collection.bindings[1].promise.var == "input_2"
+    assert w.nodes[4].inputs[0].binding.collection.bindings[2].promise.node_id == "n3"
+    assert w.nodes[4].inputs[0].binding.collection.bindings[2].promise.var == "b"
     assert w.nodes[4].inputs[0].binding.collection.bindings[3].scalar.primitive.integer == 100
-    assert w.nodes[5].inputs[0].var == 'a'
-    assert w.nodes[5].inputs[0].binding.promise.node_id == 'n5'
-    assert w.nodes[5].inputs[0].binding.promise.var == 'b'
+    assert w.nodes[5].inputs[0].var == "a"
+    assert w.nodes[5].inputs[0].binding.promise.node_id == "n5"
+    assert w.nodes[5].inputs[0].binding.promise.var == "b"
     assert len(w.outputs) == 1
-    assert w.outputs[0].var == 'a'
-    assert w.outputs[0].binding.promise.var == 'b'
-    assert w.outputs[0].binding.promise.node_id == 'n1'
-    assert w.metadata.on_failure == _workflow_models.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE
+    assert w.outputs[0].var == "a"
+    assert w.outputs[0].binding.promise.var == "b"
+    assert w.outputs[0].binding.promise.node_id == "n1"
+    assert (
+        w.metadata.on_failure == _workflow_models.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE
+    )

     # TODO: Test promotion of w -> SdkWorkflow
@@ -185,7 +190,7 @@ def test_workflow_node():
     def my_task(wf_params, a, b):
         b.set(a + 1)

-    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'my_task', 'version')
+    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "my_task", "version")

     @inputs(a=[primitives.Integer])
     @outputs(b=[primitives.Integer])
@@ -193,30 +198,29 @@ def my_task(wf_params, a, b):
     def my_list_task(wf_params, a, b):
         b.set([v + 1 for v in a])

-    my_list_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'my_list_task',
-                                              'version')
+    my_list_task._id = _identifier.Identifier(
+        _identifier.ResourceType.TASK, "project", "domain", "my_list_task", "version"
+    )

     input_list = [
-        promise.Input('required', primitives.Integer),
-        promise.Input('not_required', primitives.Integer, default=5, help='Not required.')
+        promise.Input("required", primitives.Integer),
+        promise.Input("not_required", primitives.Integer, default=5, help="Not required."),
     ]

-    n1 = my_task(a=input_list[0]).assign_id_and_return('n1')
-    n2 = my_task(a=input_list[1]).assign_id_and_return('n2')
-    n3 = my_task(a=100).assign_id_and_return('n3')
-    n4 = my_task(a=n1.outputs.b).assign_id_and_return('n4')
-    n5 = my_list_task(a=[input_list[0], input_list[1], n3.outputs.b, 100]).assign_id_and_return('n5')
+    n1 = my_task(a=input_list[0]).assign_id_and_return("n1")
+    n2 = my_task(a=input_list[1]).assign_id_and_return("n2")
+    n3 = my_task(a=100).assign_id_and_return("n3")
+    n4 = my_task(a=n1.outputs.b).assign_id_and_return("n4")
+    n5 = my_list_task(a=[input_list[0], input_list[1], n3.outputs.b, 100]).assign_id_and_return("n5")
     n6 = my_list_task(a=n5.outputs.b)

     nodes = [n1, n2, n3, n4, n5, n6]

     wf_out = [
         workflow.Output(
-            'nested_out',
-            [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]],
-            sdk_type=[[primitives.Integer]]
+            "nested_out", [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]], sdk_type=[[primitives.Integer]],
         ),
-        workflow.Output('scalar_out', n1.outputs.b, sdk_type=primitives.Integer)
+        workflow.Output("scalar_out", n1.outputs.b, sdk_type=primitives.Integer),
     ]

     w = workflow.SdkWorkflow(inputs=input_list, outputs=wf_out, nodes=nodes)
@@ -231,7 +235,7 @@ def my_list_task(wf_params, a, b):

     # Test that type checking works
     with _pytest.raises(_user_exceptions.FlyteTypeException):
-        w(required='abc', not_required=1)
+        w(required="abc", not_required=1)

     # Test that bad arg name is detected
     with _pytest.raises(_user_exceptions.FlyteAssertion):
@@ -239,33 +243,35 @@ def my_list_task(wf_params, a, b):

     # Test default input is accounted for
     n = w(required=10)
-    assert n.inputs[0].var == 'not_required'
+    assert n.inputs[0].var == "not_required"
     assert n.inputs[0].binding.scalar.primitive.integer == 5
-    assert n.inputs[1].var == 'required'
+    assert n.inputs[1].var == "required"
     assert n.inputs[1].binding.scalar.primitive.integer == 10

     # Test default input is overridden
     n = w(required=10, not_required=50)
-    assert n.inputs[0].var == 'not_required'
+    assert n.inputs[0].var == "not_required"
     assert n.inputs[0].binding.scalar.primitive.integer == 50
-    assert n.inputs[1].var == 'required'
+    assert n.inputs[1].var == "required"
     assert n.inputs[1].binding.scalar.primitive.integer == 10

     # Test that workflow is saved in the node
-    w._id = 'fake'
-    assert n.workflow_node.sub_workflow_ref == 'fake'
+    w._id = "fake"
+    assert n.workflow_node.sub_workflow_ref == "fake"
     w._id = None

     # Test that outputs are promised
-    n.assign_id_and_return('node-id*')  # dns'ified
-    assert n.outputs['scalar_out'].sdk_type.to_flyte_literal_type() == primitives.Integer.to_flyte_literal_type()
-    assert n.outputs['scalar_out'].var == 'scalar_out'
-    assert n.outputs['scalar_out'].node_id == 'node-id'
-
-    assert n.outputs['nested_out'].sdk_type.to_flyte_literal_type() == \
-        containers.List(containers.List(primitives.Integer)).to_flyte_literal_type()
-    assert n.outputs['nested_out'].var == 'nested_out'
-    assert n.outputs['nested_out'].node_id == 'node-id'
+    n.assign_id_and_return("node-id*")  # dns'ified
+    assert n.outputs["scalar_out"].sdk_type.to_flyte_literal_type() == primitives.Integer.to_flyte_literal_type()
+    assert n.outputs["scalar_out"].var == "scalar_out"
+    assert n.outputs["scalar_out"].node_id == "node-id"
+
+    assert (
+        n.outputs["nested_out"].sdk_type.to_flyte_literal_type()
+        == containers.List(containers.List(primitives.Integer)).to_flyte_literal_type()
+    )
+    assert n.outputs["nested_out"].var == "nested_out"
+    assert n.outputs["nested_out"].node_id == "node-id"


 def test_non_system_nodes():
@@ -275,31 +281,30 @@ def my_task(wf_params, a, b):
         b.set(a + 1)

-    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'my_task', 'version')
+    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "my_task", "version")

-    required_input = promise.Input('required', primitives.Integer)
+    required_input = promise.Input("required", primitives.Integer)

-    n1 = my_task(a=required_input).assign_id_and_return('n1')
+    n1 = my_task(a=required_input).assign_id_and_return("n1")

     n_start = nodes.SdkNode(
-        'start-node',
+        "start-node",
         [],
         [
             _literals.Binding(
-                'a',
-                interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3)
+                "a", interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3),
             )
         ],
         None,
         sdk_task=my_task,
         sdk_workflow=None,
         sdk_launch_plan=None,
-        sdk_branch=None
+        sdk_branch=None,
     )

     non_system_nodes = workflow.SdkWorkflow.get_non_system_nodes([n1, n_start])
     assert len(non_system_nodes) == 1
-    assert non_system_nodes[0].id == 'n1'
+    assert non_system_nodes[0].id == "n1"


 def test_workflow_serialization():
@@ -309,7 +314,7 @@ def my_task(wf_params, a, b):
         b.set(a + 1)

-    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'my_task', 'version')
+    my_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "my_task", "version")

     @inputs(a=[primitives.Integer])
     @outputs(b=[primitives.Integer])
@@ -317,30 +322,29 @@ def my_task(wf_params, a, b):
     def my_list_task(wf_params, a, b):
         b.set([v + 1 for v in a])

-    my_list_task._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'my_list_task',
-                                              'version')
+    my_list_task._id = _identifier.Identifier(
+        _identifier.ResourceType.TASK, "project", "domain", "my_list_task", "version"
+    )

     input_list = [
-        promise.Input('required', primitives.Integer),
-        promise.Input('not_required', primitives.Integer, default=5, help='Not required.')
+        promise.Input("required", primitives.Integer),
+        promise.Input("not_required", primitives.Integer, default=5, help="Not required."),
     ]

-    n1 = my_task(a=input_list[0]).assign_id_and_return('n1')
-    n2 = my_task(a=input_list[1]).assign_id_and_return('n2')
-    n3 = my_task(a=100).assign_id_and_return('n3')
-    n4 = my_task(a=n1.outputs.b).assign_id_and_return('n4')
-    n5 = my_list_task(a=[input_list[0], input_list[1], n3.outputs.b, 100]).assign_id_and_return('n5')
+    n1 = my_task(a=input_list[0]).assign_id_and_return("n1")
+    n2 = my_task(a=input_list[1]).assign_id_and_return("n2")
+    n3 = my_task(a=100).assign_id_and_return("n3")
+    n4 = my_task(a=n1.outputs.b).assign_id_and_return("n4")
+    n5 = my_list_task(a=[input_list[0], input_list[1], n3.outputs.b, 100]).assign_id_and_return("n5")
     n6 = my_list_task(a=n5.outputs.b)

     nodes = [n1, n2, n3, n4, n5, n6]

     wf_out = [
         workflow.Output(
-            'nested_out',
-            [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]],
-            sdk_type=[[primitives.Integer]]
+            "nested_out", [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]], sdk_type=[[primitives.Integer]],
         ),
-        workflow.Output('scalar_out', n1.outputs.b, sdk_type=primitives.Integer)
+        workflow.Output("scalar_out", n1.outputs.b, sdk_type=primitives.Integer),
     ]

     w = workflow.SdkWorkflow(inputs=input_list, outputs=wf_out, nodes=nodes)
diff --git a/tests/flytekit/unit/common_tests/test_workflow_promote.py b/tests/flytekit/unit/common_tests/test_workflow_promote.py
index 87c5eba0ea..22d9f53d9a 100644
--- a/tests/flytekit/unit/common_tests/test_workflow_promote.py
+++ b/tests/flytekit/unit/common_tests/test_workflow_promote.py
@@ -1,24 +1,27 @@
 from __future__ import absolute_import

 from datetime import timedelta
+from os import path as _path

+from flyteidl.core import compiler_pb2 as _compiler_pb2
+from flyteidl.core import workflow_pb2 as _workflow_pb2
 from mock import patch as _patch
-from os import path as _path

 from flytekit.common import workflow as _workflow_common
 from flytekit.common.tasks import task as _task
-from flytekit.models import interface as _interface, \
-    literals as _literals, types as _types, task as _task_model
-from flytekit.models.core import workflow as _workflow_model, identifier as _identifier, compiler as _compiler_model
+from flytekit.models import interface as _interface
+from flytekit.models import literals as _literals
+from flytekit.models import task as _task_model
+from flytekit.models import types as _types
+from flytekit.models.core import compiler as _compiler_model
+from flytekit.models.core import identifier as _identifier
+from flytekit.models.core import workflow as _workflow_model
 from flytekit.sdk import tasks as _sdk_tasks
 from flytekit.sdk import workflow as _sdk_workflow
-from flytekit.sdk.types import Types as _Types
-from flyteidl.core import compiler_pb2 as _compiler_pb2, workflow_pb2 as _workflow_pb2
-from flytekit.sdk.types import Types
-from flytekit.sdk.workflow import workflow_class, Input, Output
 from flytekit.sdk.tasks import inputs, outputs, python_task
 from flytekit.sdk.types import Types
-from flytekit.sdk.workflow import workflow_class, Input, Output
+from flytekit.sdk.types import Types as _Types
+from flytekit.sdk.workflow import Input, Output, workflow_class


 def get_sample_node_metadata(node_id):
@@ -27,11 +30,7 @@ def get_sample_node_metadata(node_id):
     """
     :rtype: flytekit.models.core.workflow.NodeMetadata
     """

-    return _workflow_model.NodeMetadata(
-        name=node_id,
-        timeout=timedelta(seconds=10),
-        retries=_literals.RetryStrategy(0)
-    )
+    return _workflow_model.NodeMetadata(name=node_id, timeout=timedelta(seconds=10), retries=_literals.RetryStrategy(0))


 def get_sample_container():
@@ -42,12 +41,7 @@ def get_sample_container():
     resources = _task_model.Resources(requests=[cpu_resource], limits=[cpu_resource])

     return _task_model.Container(
-        "my_image",
-        ["this", "is", "a", "cmd"],
-        ["this", "is", "an", "arg"],
-        resources,
-        {},
-        {}
+        "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {}, {},
     )
@@ -62,7 +56,7 @@ def get_sample_task_metadata():
         _literals.RetryStrategy(3),
         True,
         "0.1.1b0",
-        "This is
deprecated!" + "This is deprecated!", ) @@ -133,23 +127,24 @@ class TestPromoteExampleWf(object): int_type = _types.LiteralType(_types.SimpleType.INTEGER) task_interface = _interface.TypedInterface( # inputs - {'a': _interface.Variable(int_type, "description1")}, + {"a": _interface.Variable(int_type, "description1")}, # outputs - { - 'b': _interface.Variable(int_type, "description2"), - 'c': _interface.Variable(int_type, "description3") - } + {"b": _interface.Variable(int_type, "description2"), "c": _interface.Variable(int_type, "description3")}, ) # Since the promotion of a workflow requires retrieving the task from Admin, we mock the SdkTask to return task_template = _task_model.TaskTemplate( - _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", - "tests.flytekit.unit.common_tests.test_workflow_promote.demo_task_for_promote", - "version"), + _identifier.Identifier( + _identifier.ResourceType.TASK, + "project", + "domain", + "tests.flytekit.unit.common_tests.test_workflow_promote.demo_task_for_promote", + "version", + ), "python_container", get_sample_task_metadata(), task_interface, custom={}, - container=get_sample_container() + container=get_sample_container(), ) sdk_promoted_task = _task.SdkTask.promote_from_model(task_template) mock_task_fetch.return_value = sdk_promoted_task diff --git a/tests/flytekit/unit/common_tests/types/impl/test_blobs.py b/tests/flytekit/unit/common_tests/types/impl/test_blobs.py index bc22e61aa1..aa9931df8e 100644 --- a/tests/flytekit/unit/common_tests/types/impl/test_blobs.py +++ b/tests/flytekit/unit/common_tests/types/impl/test_blobs.py @@ -1,11 +1,14 @@ from __future__ import absolute_import + +import os + +import pytest + from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types.impl import blobs from flytekit.common.utils import AutoDeletingTempDir from flytekit.models.core import types as _core_types from flytekit.sdk import test_utils -import pytest -import os def test_blob(): @@ -19,18 +22,18 @@ def test_blob(): def test_blob_from_python_std(): with test_utils.LocalTestFileSystem() as t: - with AutoDeletingTempDir('test') as wd: + with AutoDeletingTempDir("test") as wd: tmp_name = wd.get_named_tempfile("from_python_std") - with open(tmp_name, 'wb') as w: - w.write("hello hello".encode('utf-8')) + with open(tmp_name, "wb") as w: + w.write("hello hello".encode("utf-8")) b = blobs.Blob.from_python_std(tmp_name) assert b.mode == "wb" assert b.metadata.type.format == "" assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE assert b.remote_location.startswith(t.name) assert b.local_path == tmp_name - with open(b.remote_location, 'rb') as r: - assert r.read() == "hello hello".encode('utf-8') + with open(b.remote_location, "rb") as r: + assert r.read() == "hello hello".encode("utf-8") b = blobs.Blob("/tmp/fake") b2 = blobs.Blob.from_python_std(b) @@ -42,84 +45,84 @@ def test_blob_from_python_std(): def test_blob_create_at(): with test_utils.LocalTestFileSystem() as t: - with AutoDeletingTempDir('test') as wd: - tmp_name = wd.get_named_tempfile('tmp') + with AutoDeletingTempDir("test") as wd: + tmp_name = wd.get_named_tempfile("tmp") b = blobs.Blob.create_at_known_location(tmp_name) assert b.local_path is None assert b.remote_location == tmp_name - assert b.mode == 'wb' + assert b.mode == "wb" assert b.metadata.type.format == "" assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE with b as w: - w.write("hello hello".encode('utf-8')) + 
w.write("hello hello".encode("utf-8")) assert b.local_path.startswith(t.name) - with open(tmp_name, 'rb') as r: - assert r.read() == "hello hello".encode('utf-8') + with open(tmp_name, "rb") as r: + assert r.read() == "hello hello".encode("utf-8") def test_blob_fetch_managed(): - with AutoDeletingTempDir('test') as wd: + with AutoDeletingTempDir("test") as wd: with test_utils.LocalTestFileSystem() as t: - tmp_name = wd.get_named_tempfile('tmp') - with open(tmp_name, 'wb') as w: - w.write("hello".encode('utf-8')) + tmp_name = wd.get_named_tempfile("tmp") + with open(tmp_name, "wb") as w: + w.write("hello".encode("utf-8")) b = blobs.Blob.fetch(tmp_name) assert b.local_path.startswith(t.name) assert b.remote_location == tmp_name - assert b.mode == 'rb' + assert b.mode == "rb" assert b.metadata.type.format == "" assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE with b as r: - assert r.read() == "hello".encode('utf-8') + assert r.read() == "hello".encode("utf-8") with pytest.raises(_user_exceptions.FlyteAssertion): blobs.Blob.fetch(tmp_name, local_path=b.local_path) - with open(tmp_name, 'wb') as w: - w.write("bye".encode('utf-8')) + with open(tmp_name, "wb") as w: + w.write("bye".encode("utf-8")) b2 = blobs.Blob.fetch(tmp_name, local_path=b.local_path, overwrite=True) with b2 as r: - assert r.read() == "bye".encode('utf-8') + assert r.read() == "bye".encode("utf-8") with pytest.raises(_user_exceptions.FlyteAssertion): blobs.Blob.fetch(tmp_name) def test_blob_fetch_unmanaged(): - with AutoDeletingTempDir('test') as wd: - with AutoDeletingTempDir('test2') as t: - tmp_name = wd.get_named_tempfile('source') - tmp_sink = t.get_named_tempfile('sink') - with open(tmp_name, 'wb') as w: - w.write("hello".encode('utf-8')) + with AutoDeletingTempDir("test") as wd: + with AutoDeletingTempDir("test2") as t: + tmp_name = wd.get_named_tempfile("source") + tmp_sink = t.get_named_tempfile("sink") + with open(tmp_name, "wb") as w: + w.write("hello".encode("utf-8")) b = blobs.Blob.fetch(tmp_name, local_path=tmp_sink) assert b.local_path == tmp_sink assert b.remote_location == tmp_name - assert b.mode == 'rb' + assert b.mode == "rb" assert b.metadata.type.format == "" assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE with b as r: - assert r.read() == "hello".encode('utf-8') + assert r.read() == "hello".encode("utf-8") with pytest.raises(_user_exceptions.FlyteAssertion): blobs.Blob.fetch(tmp_name, local_path=tmp_sink) - with open(tmp_name, 'wb') as w: - w.write("bye".encode('utf-8')) + with open(tmp_name, "wb") as w: + w.write("bye".encode("utf-8")) b2 = blobs.Blob.fetch(tmp_name, local_path=tmp_sink, overwrite=True) with b2 as r: - assert r.read() == "bye".encode('utf-8') + assert r.read() == "bye".encode("utf-8") def test_blob_double_enter(): with test_utils.LocalTestFileSystem(): - with AutoDeletingTempDir('test') as wd: - b = blobs.Blob(wd.get_named_tempfile("sink"), mode='wb') + with AutoDeletingTempDir("test") as wd: + b = blobs.Blob(wd.get_named_tempfile("sink"), mode="wb") with b: with pytest.raises(_user_exceptions.FlyteAssertion): with b: @@ -127,33 +130,33 @@ def test_blob_double_enter(): def test_blob_download_managed(): - with AutoDeletingTempDir('test') as wd: + with AutoDeletingTempDir("test") as wd: with test_utils.LocalTestFileSystem() as t: - tmp_name = wd.get_named_tempfile('tmp') - with open(tmp_name, 'wb') as w: - w.write("hello".encode('utf-8')) + tmp_name = wd.get_named_tempfile("tmp") + with open(tmp_name, "wb") as 
w: + w.write("hello".encode("utf-8")) b = blobs.Blob(tmp_name) b.download() assert b.local_path.startswith(t.name) assert b.remote_location == tmp_name - assert b.mode == 'rb' + assert b.mode == "rb" assert b.metadata.type.format == "" assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE with b as r: - assert r.read() == "hello".encode('utf-8') + assert r.read() == "hello".encode("utf-8") b2 = blobs.Blob(tmp_name) with pytest.raises(_user_exceptions.FlyteAssertion): b2.download(b.local_path) - with open(tmp_name, 'wb') as w: - w.write("bye".encode('utf-8')) + with open(tmp_name, "wb") as w: + w.write("bye".encode("utf-8")) b2 = blobs.Blob(tmp_name) b2.download(local_path=b.local_path, overwrite=True) with b2 as r: - assert r.read() == "bye".encode('utf-8') + assert r.read() == "bye".encode("utf-8") b = blobs.Blob(tmp_name) with pytest.raises(_user_exceptions.FlyteAssertion): @@ -161,38 +164,38 @@ def test_blob_download_managed(): def test_blob_download_unmanaged(): - with AutoDeletingTempDir('test') as wd: - with AutoDeletingTempDir('test2') as t: - tmp_name = wd.get_named_tempfile('source') - tmp_sink = t.get_named_tempfile('sink') - with open(tmp_name, 'wb') as w: - w.write("hello".encode('utf-8')) + with AutoDeletingTempDir("test") as wd: + with AutoDeletingTempDir("test2") as t: + tmp_name = wd.get_named_tempfile("source") + tmp_sink = t.get_named_tempfile("sink") + with open(tmp_name, "wb") as w: + w.write("hello".encode("utf-8")) b = blobs.Blob(tmp_name) b.download(tmp_sink) assert b.local_path == tmp_sink assert b.remote_location == tmp_name - assert b.mode == 'rb' + assert b.mode == "rb" assert b.metadata.type.format == "" assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE with b as r: - assert r.read() == "hello".encode('utf-8') + assert r.read() == "hello".encode("utf-8") b = blobs.Blob(tmp_name) with pytest.raises(_user_exceptions.FlyteAssertion): b.download(tmp_sink) - with open(tmp_name, 'wb') as w: - w.write("bye".encode('utf-8')) + with open(tmp_name, "wb") as w: + w.write("bye".encode("utf-8")) b2 = blobs.Blob(tmp_name) b2.download(tmp_sink, overwrite=True) with b2 as r: - assert r.read() == "bye".encode('utf-8') + assert r.read() == "bye".encode("utf-8") def test_multipart_blob(): - b = blobs.MultiPartBlob("/tmp/fake", mode='w', format='csv') + b = blobs.MultiPartBlob("/tmp/fake", mode="w", format="csv") assert b.remote_location == "/tmp/fake/" assert b.local_path is None assert b.mode == "w" @@ -202,19 +205,19 @@ def test_multipart_blob(): def _generate_multipart_blob_data(tmp_dir): n = tmp_dir.get_named_tempfile("0") - with open(n, 'wb') as w: - w.write("part0".encode('utf-8')) + with open(n, "wb") as w: + w.write("part0".encode("utf-8")) n = tmp_dir.get_named_tempfile("1") - with open(n, 'wb') as w: - w.write("part1".encode('utf-8')) + with open(n, "wb") as w: + w.write("part1".encode("utf-8")) n = tmp_dir.get_named_tempfile("2") - with open(n, 'wb') as w: - w.write("part2".encode('utf-8')) + with open(n, "wb") as w: + w.write("part2".encode("utf-8")) def test_multipart_blob_from_python_std(): with test_utils.LocalTestFileSystem() as t: - with AutoDeletingTempDir('test') as wd: + with AutoDeletingTempDir("test") as wd: _generate_multipart_blob_data(wd) b = blobs.MultiPartBlob.from_python_std(wd.name) assert b.mode == "wb" @@ -222,12 +225,12 @@ def test_multipart_blob_from_python_std(): assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART assert 
b.remote_location.startswith(t.name) assert b.local_path == wd.name - with open(os.path.join(b.remote_location, '0'), 'rb') as r: - assert r.read() == "part0".encode('utf-8') - with open(os.path.join(b.remote_location, '1'), 'rb') as r: - assert r.read() == "part1".encode('utf-8') - with open(os.path.join(b.remote_location, '2'), 'rb') as r: - assert r.read() == "part2".encode('utf-8') + with open(os.path.join(b.remote_location, "0"), "rb") as r: + assert r.read() == "part0".encode("utf-8") + with open(os.path.join(b.remote_location, "1"), "rb") as r: + assert r.read() == "part1".encode("utf-8") + with open(os.path.join(b.remote_location, "2"), "rb") as r: + assert r.read() == "part2".encode("utf-8") b = blobs.MultiPartBlob("/tmp/fake/") b2 = blobs.MultiPartBlob.from_python_std(b) @@ -239,88 +242,88 @@ def test_multipart_blob_from_python_std(): def test_multipart_blob_create_at(): with test_utils.LocalTestFileSystem(): - with AutoDeletingTempDir('test') as wd: + with AutoDeletingTempDir("test") as wd: b = blobs.MultiPartBlob.create_at_known_location(wd.name) assert b.local_path is None assert b.remote_location == wd.name + "/" - assert b.mode == 'wb' + assert b.mode == "wb" assert b.metadata.type.format == "" assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART - with b.create_part('0') as w: - w.write("part0".encode('utf-8')) - with b.create_part('1') as w: - w.write("part1".encode('utf-8')) - with b.create_part('2') as w: - w.write("part2".encode('utf-8')) + with b.create_part("0") as w: + w.write("part0".encode("utf-8")) + with b.create_part("1") as w: + w.write("part1".encode("utf-8")) + with b.create_part("2") as w: + w.write("part2".encode("utf-8")) - with open(os.path.join(wd.name, '0'), 'rb') as r: - assert r.read() == "part0".encode('utf-8') - with open(os.path.join(wd.name, '1'), 'rb') as r: - assert r.read() == "part1".encode('utf-8') - with open(os.path.join(wd.name, '2'), 'rb') as r: - assert r.read() == "part2".encode('utf-8') + with open(os.path.join(wd.name, "0"), "rb") as r: + assert r.read() == "part0".encode("utf-8") + with open(os.path.join(wd.name, "1"), "rb") as r: + assert r.read() == "part1".encode("utf-8") + with open(os.path.join(wd.name, "2"), "rb") as r: + assert r.read() == "part2".encode("utf-8") def test_multipart_blob_fetch_managed(): - with AutoDeletingTempDir('test') as wd: + with AutoDeletingTempDir("test") as wd: with test_utils.LocalTestFileSystem() as t: _generate_multipart_blob_data(wd) b = blobs.MultiPartBlob.fetch(wd.name) assert b.local_path.startswith(t.name) assert b.remote_location == wd.name + "/" - assert b.mode == 'rb' + assert b.mode == "rb" assert b.metadata.type.format == "" assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART with b as r: - assert r[0].read() == "part0".encode('utf-8') - assert r[1].read() == "part1".encode('utf-8') - assert r[2].read() == "part2".encode('utf-8') + assert r[0].read() == "part0".encode("utf-8") + assert r[1].read() == "part1".encode("utf-8") + assert r[2].read() == "part2".encode("utf-8") with pytest.raises(_user_exceptions.FlyteAssertion): blobs.MultiPartBlob.fetch(wd.name, local_path=b.local_path) - with open(os.path.join(wd.name, "0"), 'wb') as w: - w.write("bye".encode('utf-8')) + with open(os.path.join(wd.name, "0"), "wb") as w: + w.write("bye".encode("utf-8")) b2 = blobs.MultiPartBlob.fetch(wd.name, local_path=b.local_path, overwrite=True) with b2 as r: - assert r[0].read() == "bye".encode('utf-8') - assert r[1].read() == 
"part1".encode('utf-8') - assert r[2].read() == "part2".encode('utf-8') + assert r[0].read() == "bye".encode("utf-8") + assert r[1].read() == "part1".encode("utf-8") + assert r[2].read() == "part2".encode("utf-8") with pytest.raises(_user_exceptions.FlyteAssertion): blobs.Blob.fetch(wd.name) def test_multipart_blob_fetch_unmanaged(): - with AutoDeletingTempDir('test') as wd: - with AutoDeletingTempDir('test2') as t: + with AutoDeletingTempDir("test") as wd: + with AutoDeletingTempDir("test2") as t: _generate_multipart_blob_data(wd) - tmp_sink = t.get_named_tempfile('sink') + tmp_sink = t.get_named_tempfile("sink") b = blobs.MultiPartBlob.fetch(wd.name, local_path=tmp_sink) assert b.local_path == tmp_sink assert b.remote_location == wd.name + "/" - assert b.mode == 'rb' + assert b.mode == "rb" assert b.metadata.type.format == "" assert b.metadata.type.dimensionality == _core_types.BlobType.BlobDimensionality.MULTIPART with b as r: - assert r[0].read() == "part0".encode('utf-8') - assert r[1].read() == "part1".encode('utf-8') - assert r[2].read() == "part2".encode('utf-8') + assert r[0].read() == "part0".encode("utf-8") + assert r[1].read() == "part1".encode("utf-8") + assert r[2].read() == "part2".encode("utf-8") with pytest.raises(_user_exceptions.FlyteAssertion): blobs.MultiPartBlob.fetch(wd.name, local_path=tmp_sink) - with open(os.path.join(wd.name, "0"), 'wb') as w: - w.write("bye".encode('utf-8')) + with open(os.path.join(wd.name, "0"), "wb") as w: + w.write("bye".encode("utf-8")) b2 = blobs.MultiPartBlob.fetch(wd.name, local_path=tmp_sink, overwrite=True) with b2 as r: - assert r[0].read() == "bye".encode('utf-8') - assert r[1].read() == "part1".encode('utf-8') - assert r[2].read() == "part2".encode('utf-8') + assert r[0].read() == "bye".encode("utf-8") + assert r[1].read() == "part1".encode("utf-8") + assert r[2].read() == "part2".encode("utf-8") def test_multipart_blob_no_enter_on_write(): diff --git a/tests/flytekit/unit/common_tests/types/impl/test_schema.py b/tests/flytekit/unit/common_tests/types/impl/test_schema.py index aa2b253b43..3576676b3e 100644 --- a/tests/flytekit/unit/common_tests/types/impl/test_schema.py +++ b/tests/flytekit/unit/common_tests/types/impl/test_schema.py @@ -1,64 +1,68 @@ from __future__ import absolute_import +import collections as _collections import datetime as _datetime import os as _os import uuid as _uuid -import collections as _collections import pandas as _pd import pytest as _pytest import six.moves as _six_moves from flytekit.common import utils as _utils from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import primitives as _primitives, blobs as _blobs +from flytekit.common.types import blobs as _blobs +from flytekit.common.types import primitives as _primitives from flytekit.common.types.impl import schema as _schema_impl -from flytekit.models import types as _type_models, literals as _literal_models +from flytekit.models import literals as _literal_models +from flytekit.models import types as _type_models from flytekit.sdk import test_utils as _test_utils def test_schema_type(): _schema_impl.SchemaType() _schema_impl.SchemaType([]) - _schema_impl.SchemaType([ - ('a', _primitives.Integer), - ('b', _primitives.String), - ('c', _primitives.Float), - ('d', _primitives.Boolean), - ('e', _primitives.Datetime) - ]) + _schema_impl.SchemaType( + [ + ("a", _primitives.Integer), + ("b", _primitives.String), + ("c", _primitives.Float), + ("d", _primitives.Boolean), + ("e", _primitives.Datetime), + ] + ) with 
_pytest.raises(ValueError): - _schema_impl.SchemaType({'a': _primitives.Integer}) + _schema_impl.SchemaType({"a": _primitives.Integer}) with _pytest.raises(TypeError): - _schema_impl.SchemaType([('a', _blobs.Blob)]) + _schema_impl.SchemaType([("a", _blobs.Blob)]) with _pytest.raises(ValueError): - _schema_impl.SchemaType([('a', _primitives.Integer, 1)]) + _schema_impl.SchemaType([("a", _primitives.Integer, 1)]) - _schema_impl.SchemaType([('1', _primitives.Integer)]) + _schema_impl.SchemaType([("1", _primitives.Integer)]) with _pytest.raises(TypeError): _schema_impl.SchemaType([(1, _primitives.Integer)]) with _pytest.raises(TypeError): - _schema_impl.SchemaType([('1', [_primitives.Integer])]) + _schema_impl.SchemaType([("1", [_primitives.Integer])]) value_type_tuples = [ - ('abra', _primitives.Integer, [1, 2, 3, 4, 5]), - ('CADABRA', _primitives.Float, [1.0, 2.0, 3.0, 4.0, 5.0]), - ('HoCuS', _primitives.String, ["A", "B", "C", "D", "E"]), - ('Pocus', _primitives.Boolean, [True, False, True, False]), + ("abra", _primitives.Integer, [1, 2, 3, 4, 5]), + ("CADABRA", _primitives.Float, [1.0, 2.0, 3.0, 4.0, 5.0]), + ("HoCuS", _primitives.String, ["A", "B", "C", "D", "E"]), + ("Pocus", _primitives.Boolean, [True, False, True, False]), ( - 'locusts', + "locusts", _primitives.Datetime, [ - _datetime.datetime(day=1, month=1, year=2017, hour=1, minute=1, second=1, microsecond=1) - - _datetime.timedelta(days=i) + _datetime.datetime(day=1, month=1, year=2017, hour=1, minute=1, second=1, microsecond=1) + - _datetime.timedelta(days=i) for i in _six_moves.range(5) - ] - ) + ], + ), ] @@ -70,7 +74,7 @@ def test_simple_read_and_write_with_different_types(value_type_pair): with _test_utils.LocalTestFileSystem() as sandbox: with _utils.AutoDeletingTempDir("test") as t: - a = _schema_impl.Schema.create_at_known_location(t.name, mode='wb', schema_type=schema_type) + a = _schema_impl.Schema.create_at_known_location(t.name, mode="wb", schema_type=schema_type) assert a.local_path is None with a as writer: for _ in _six_moves.range(5): @@ -78,7 +82,7 @@ def test_simple_read_and_write_with_different_types(value_type_pair): assert a.local_path.startswith(sandbox.name) assert a.local_path is None - b = _schema_impl.Schema.create_at_known_location(t.name, mode='rb', schema_type=schema_type) + b = _schema_impl.Schema.create_at_known_location(t.name, mode="rb", schema_type=schema_type) assert b.local_path is None with b as reader: for df in reader.iter_chunks(): @@ -100,36 +104,41 @@ def test_datetime_coercion_explicitly(): """ dt = _datetime.datetime(day=1, month=1, year=2017, hour=1, minute=1, second=1, microsecond=1) values = [(dt,)] - df = _pd.DataFrame.from_records(values, columns=['testname']) - assert df['testname'][0] == dt + df = _pd.DataFrame.from_records(values, columns=["testname"]) + assert df["testname"][0] == dt - with _utils.AutoDeletingTempDir('test') as tmpdir: - tmpfile = tmpdir.get_named_tempfile('repro.parquet') - df.to_parquet(tmpfile, coerce_timestamps='ms', allow_truncated_timestamps=True) + with _utils.AutoDeletingTempDir("test") as tmpdir: + tmpfile = tmpdir.get_named_tempfile("repro.parquet") + df.to_parquet(tmpfile, coerce_timestamps="ms", allow_truncated_timestamps=True) df2 = _pd.read_parquet(tmpfile) dt2 = _datetime.datetime(day=1, month=1, year=2017, hour=1, minute=1, second=1) - assert df2['testname'][0] == dt2 + assert df2["testname"][0] == dt2 def test_datetime_coercion(): values = [ - tuple([_datetime.datetime(day=1, month=1, year=2017, hour=1, minute=1, second=1, microsecond=1) - - 
_datetime.timedelta(days=x)]) + tuple( + [ + _datetime.datetime(day=1, month=1, year=2017, hour=1, minute=1, second=1, microsecond=1) + - _datetime.timedelta(days=x) + ] + ) for x in _six_moves.range(5) ] - schema_type = _schema_impl.SchemaType(columns=[('testname', _primitives.Datetime)]) + schema_type = _schema_impl.SchemaType(columns=[("testname", _primitives.Datetime)]) with _test_utils.LocalTestFileSystem(): with _utils.AutoDeletingTempDir("test") as t: - a = _schema_impl.Schema.create_at_known_location(t.name, mode='wb', schema_type=schema_type) + a = _schema_impl.Schema.create_at_known_location(t.name, mode="wb", schema_type=schema_type) with a as writer: for _ in _six_moves.range(5): # us to ms coercion segfaults unless we explicitly allow truncation. writer.write( - _pd.DataFrame.from_records(values, columns=['testname']), - coerce_timestamps='ms', - allow_truncated_timestamps=True) + _pd.DataFrame.from_records(values, columns=["testname"]), + coerce_timestamps="ms", + allow_truncated_timestamps=True, + ) # TODO: Uncomment when segfault bug is resolved # with _pytest.raises(Exception): @@ -137,10 +146,10 @@ def test_datetime_coercion(): # _pd.DataFrame.from_records(values, columns=['testname']), # coerce_timestamps='ms') - b = _schema_impl.Schema.create_at_known_location(t.name, mode='wb', schema_type=schema_type) + b = _schema_impl.Schema.create_at_known_location(t.name, mode="wb", schema_type=schema_type) with b as writer: for _ in _six_moves.range(5): - writer.write(_pd.DataFrame.from_records(values, columns=['testname'])) + writer.write(_pd.DataFrame.from_records(values, columns=["testname"])) @_pytest.mark.parametrize("value_type_pair", value_type_tuples) @@ -152,13 +161,12 @@ def test_fetch(value_type_pair): with _utils.AutoDeletingTempDir("test") as tmpdir: for i in _six_moves.range(3): _pd.DataFrame.from_records(values, columns=[column_name]).to_parquet( - tmpdir.get_named_tempfile(str(i).zfill(6)), coerce_timestamps='us') + tmpdir.get_named_tempfile(str(i).zfill(6)), coerce_timestamps="us" + ) with _utils.AutoDeletingTempDir("test2") as local_dir: schema_obj = _schema_impl.Schema.fetch( - tmpdir.name, - local_path=local_dir.get_named_tempfile('schema_test'), - schema_type=schema_type + tmpdir.name, local_path=local_dir.get_named_tempfile("schema_test"), schema_type=schema_type, ) with schema_obj as reader: for df in reader.iter_chunks(): @@ -180,7 +188,8 @@ def test_download(value_type_pair): with _utils.AutoDeletingTempDir("test") as tmpdir: for i in _six_moves.range(3): _pd.DataFrame.from_records(values, columns=[column_name]).to_parquet( - tmpdir.get_named_tempfile(str(i).zfill(6)), coerce_timestamps='us') + tmpdir.get_named_tempfile(str(i).zfill(6)), coerce_timestamps="us" + ) with _utils.AutoDeletingTempDir("test2") as local_dir: schema_obj = _schema_impl.Schema(tmpdir.name, schema_type=schema_type) @@ -217,7 +226,7 @@ def test_hive_queries(monkeypatch): def return_deterministic_uuid(): class FakeUUID4(object): def __init__(self): - self.hex = 'test_uuid' + self.hex = "test_uuid" class Uuid(object): def uuid4(self): @@ -225,22 +234,24 @@ def uuid4(self): return Uuid() - monkeypatch.setattr(_schema_impl, '_uuid', return_deterministic_uuid()) + monkeypatch.setattr(_schema_impl, "_uuid", return_deterministic_uuid()) - all_types = _schema_impl.SchemaType([ - ('a', _primitives.Integer), - ('b', _primitives.String), - ('c', _primitives.Float), - ('d', _primitives.Boolean), - ('e', _primitives.Datetime) - ]) + all_types = _schema_impl.SchemaType( + [ + ("a", 
_primitives.Integer), + ("b", _primitives.String), + ("c", _primitives.Float), + ("d", _primitives.Boolean), + ("e", _primitives.Datetime), + ] + ) with _test_utils.LocalTestFileSystem(): df, query = _schema_impl.Schema.create_from_hive_query( "SELECT a, b, c, d, e FROM some_place WHERE i = 0", stage_query="CREATE TEMPORARY TABLE some_place AS SELECT * FROM some_place_original", known_location="s3://my_fixed_path/", - schema_type=all_types + schema_type=all_types, ) full_query = """ @@ -274,8 +285,8 @@ def uuid4(self): ) SET LOCATION 's3://my_fixed_path/'; """ query = df.get_write_partition_to_hive_table_query( - 'some_table', - partitions=_collections.OrderedDict([('region', 'SEA'), ('ds', '2017-01-01')])) + "some_table", partitions=_collections.OrderedDict([("region", "SEA"), ("ds", "2017-01-01")]), + ) full_query = " ".join(full_query.split()) query = " ".join(query.split()) assert query == full_query @@ -284,19 +295,18 @@ def uuid4(self): def test_partial_column_read(): with _test_utils.LocalTestFileSystem(): a = _schema_impl.Schema.create_at_any_location( - schema_type=_schema_impl.SchemaType([('a', _primitives.Integer), ('b', _primitives.Integer)]) + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]) ) with a as writer: - writer.write(_pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]})) + writer.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) b = _schema_impl.Schema.fetch( - a.uri, - schema_type=_schema_impl.SchemaType([('a', _primitives.Integer), ('b', _primitives.Integer)]) + a.uri, schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), ) with b as reader: - df = reader.read(columns=['b']) - assert df.columns.values == ['b'] - assert df['b'].tolist() == [5, 6, 7, 8] + df = reader.read(columns=["b"]) + assert df.columns.values == ["b"] + assert df["b"].tolist() == [5, 6, 7, 8] def test_casting(): @@ -305,48 +315,60 @@ def test_casting(): def test_from_python_std(): with _test_utils.LocalTestFileSystem(): + def single_dataframe(): - df1 = _pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]}) - s = _schema_impl.Schema.from_python_std(t_value=df1, schema_type=_schema_impl.SchemaType( - [('a', _primitives.Integer), ('b', _primitives.Integer)])) + df1 = _pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) + s = _schema_impl.Schema.from_python_std( + t_value=df1, + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), + ) assert s is not None - n = _schema_impl.Schema.fetch(s.uri, schema_type=_schema_impl.SchemaType( - [('a', _primitives.Integer), ('b', _primitives.Integer)])) + n = _schema_impl.Schema.fetch( + s.uri, schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), + ) with n as reader: df2 = reader.read() assert df2.columns.values.all() == df1.columns.values.all() - assert df2['b'].tolist() == df1['b'].tolist() + assert df2["b"].tolist() == df1["b"].tolist() def list_of_dataframes(): - df1 = _pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]}) - df2 = _pd.DataFrame.from_dict({'a': [9, 10, 11, 12], 'b': [13, 14, 15, 16]}) - s = _schema_impl.Schema.from_python_std(t_value=[df1, df2], schema_type=_schema_impl.SchemaType( - [('a', _primitives.Integer), ('b', _primitives.Integer)])) + df1 = _pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) + df2 = _pd.DataFrame.from_dict({"a": [9, 10, 11, 12], "b": [13, 14, 15, 16]}) + s = 
_schema_impl.Schema.from_python_std( + t_value=[df1, df2], + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), + ) assert s is not None - n = _schema_impl.Schema.fetch(s.uri, schema_type=_schema_impl.SchemaType( - [('a', _primitives.Integer), ('b', _primitives.Integer)])) + n = _schema_impl.Schema.fetch( + s.uri, schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), + ) with n as reader: actual = [] for df in reader.iter_chunks(): assert df.columns.values.all() == df1.columns.values.all() - actual.extend(df['b'].tolist()) - b_val = df1['b'].tolist() - b_val.extend(df2['b'].tolist()) + actual.extend(df["b"].tolist()) + b_val = df1["b"].tolist() + b_val.extend(df2["b"].tolist()) assert actual == b_val def mixed_list(): - df1 = _pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]}) + df1 = _pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) df2 = [1, 2, 3] with _pytest.raises(_user_exceptions.FlyteTypeException): - _schema_impl.Schema.from_python_std(t_value=[df1, df2], schema_type=_schema_impl.SchemaType( - [('a', _primitives.Integer), ('b', _primitives.Integer)])) + _schema_impl.Schema.from_python_std( + t_value=[df1, df2], + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), + ) def empty_list(): - s = _schema_impl.Schema.from_python_std(t_value=[], schema_type=_schema_impl.SchemaType( - [('a', _primitives.Integer), ('b', _primitives.Integer)])) + s = _schema_impl.Schema.from_python_std( + t_value=[], + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), + ) assert s is not None - n = _schema_impl.Schema.fetch(s.uri, schema_type=_schema_impl.SchemaType( - [('a', _primitives.Integer), ('b', _primitives.Integer)])) + n = _schema_impl.Schema.fetch( + s.uri, schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), + ) with n as reader: df = reader.read() assert df is None @@ -360,40 +382,22 @@ def empty_list(): def test_promote_from_model_schema_type(): m = _type_models.SchemaType( [ - _type_models.SchemaType.SchemaColumn( - "a", - _type_models.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN - ), - _type_models.SchemaType.SchemaColumn( - "b", - _type_models.SchemaType.SchemaColumn.SchemaColumnType.DATETIME - ), - _type_models.SchemaType.SchemaColumn( - "c", - _type_models.SchemaType.SchemaColumn.SchemaColumnType.DURATION - ), - _type_models.SchemaType.SchemaColumn( - "d", - _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT - ), - _type_models.SchemaType.SchemaColumn( - "e", - _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER - ), - _type_models.SchemaType.SchemaColumn( - "f", - _type_models.SchemaType.SchemaColumn.SchemaColumnType.STRING - ), + _type_models.SchemaType.SchemaColumn("a", _type_models.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), + _type_models.SchemaType.SchemaColumn("b", _type_models.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), + _type_models.SchemaType.SchemaColumn("c", _type_models.SchemaType.SchemaColumn.SchemaColumnType.DURATION), + _type_models.SchemaType.SchemaColumn("d", _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), + _type_models.SchemaType.SchemaColumn("e", _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), + _type_models.SchemaType.SchemaColumn("f", _type_models.SchemaType.SchemaColumn.SchemaColumnType.STRING), ] ) s = _schema_impl.SchemaType.promote_from_model(m) assert 
s.columns == m.columns - assert s.sdk_columns['a'].to_flyte_literal_type() == _primitives.Boolean.to_flyte_literal_type() - assert s.sdk_columns['b'].to_flyte_literal_type() == _primitives.Datetime.to_flyte_literal_type() - assert s.sdk_columns['c'].to_flyte_literal_type() == _primitives.Timedelta.to_flyte_literal_type() - assert s.sdk_columns['d'].to_flyte_literal_type() == _primitives.Float.to_flyte_literal_type() - assert s.sdk_columns['e'].to_flyte_literal_type() == _primitives.Integer.to_flyte_literal_type() - assert s.sdk_columns['f'].to_flyte_literal_type() == _primitives.String.to_flyte_literal_type() + assert s.sdk_columns["a"].to_flyte_literal_type() == _primitives.Boolean.to_flyte_literal_type() + assert s.sdk_columns["b"].to_flyte_literal_type() == _primitives.Datetime.to_flyte_literal_type() + assert s.sdk_columns["c"].to_flyte_literal_type() == _primitives.Timedelta.to_flyte_literal_type() + assert s.sdk_columns["d"].to_flyte_literal_type() == _primitives.Float.to_flyte_literal_type() + assert s.sdk_columns["e"].to_flyte_literal_type() == _primitives.Integer.to_flyte_literal_type() + assert s.sdk_columns["f"].to_flyte_literal_type() == _primitives.String.to_flyte_literal_type() assert s == m @@ -403,157 +407,141 @@ def test_promote_from_model_schema(): _type_models.SchemaType( [ _type_models.SchemaType.SchemaColumn( - "a", - _type_models.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN - ), - _type_models.SchemaType.SchemaColumn( - "b", - _type_models.SchemaType.SchemaColumn.SchemaColumnType.DATETIME + "a", _type_models.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN ), _type_models.SchemaType.SchemaColumn( - "c", - _type_models.SchemaType.SchemaColumn.SchemaColumnType.DURATION + "b", _type_models.SchemaType.SchemaColumn.SchemaColumnType.DATETIME ), _type_models.SchemaType.SchemaColumn( - "d", - _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT + "c", _type_models.SchemaType.SchemaColumn.SchemaColumnType.DURATION ), + _type_models.SchemaType.SchemaColumn("d", _type_models.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), _type_models.SchemaType.SchemaColumn( - "e", - _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER - ), - _type_models.SchemaType.SchemaColumn( - "f", - _type_models.SchemaType.SchemaColumn.SchemaColumnType.STRING + "e", _type_models.SchemaType.SchemaColumn.SchemaColumnType.INTEGER ), + _type_models.SchemaType.SchemaColumn("f", _type_models.SchemaType.SchemaColumn.SchemaColumnType.STRING), ] - ) + ), ) s = _schema_impl.Schema.promote_from_model(m) assert s.uri == "s3://some/place/" - assert s.type.sdk_columns['a'].to_flyte_literal_type() == _primitives.Boolean.to_flyte_literal_type() - assert s.type.sdk_columns['b'].to_flyte_literal_type() == _primitives.Datetime.to_flyte_literal_type() - assert s.type.sdk_columns['c'].to_flyte_literal_type() == _primitives.Timedelta.to_flyte_literal_type() - assert s.type.sdk_columns['d'].to_flyte_literal_type() == _primitives.Float.to_flyte_literal_type() - assert s.type.sdk_columns['e'].to_flyte_literal_type() == _primitives.Integer.to_flyte_literal_type() - assert s.type.sdk_columns['f'].to_flyte_literal_type() == _primitives.String.to_flyte_literal_type() + assert s.type.sdk_columns["a"].to_flyte_literal_type() == _primitives.Boolean.to_flyte_literal_type() + assert s.type.sdk_columns["b"].to_flyte_literal_type() == _primitives.Datetime.to_flyte_literal_type() + assert s.type.sdk_columns["c"].to_flyte_literal_type() == _primitives.Timedelta.to_flyte_literal_type() + assert 
s.type.sdk_columns["d"].to_flyte_literal_type() == _primitives.Float.to_flyte_literal_type() + assert s.type.sdk_columns["e"].to_flyte_literal_type() == _primitives.Integer.to_flyte_literal_type() + assert s.type.sdk_columns["f"].to_flyte_literal_type() == _primitives.String.to_flyte_literal_type() assert s == m def test_create_at_known_location(): with _test_utils.LocalTestFileSystem(): - with _utils.AutoDeletingTempDir('test') as wd: + with _utils.AutoDeletingTempDir("test") as wd: b = _schema_impl.Schema.create_at_known_location(wd.name, schema_type=_schema_impl.SchemaType()) assert b.local_path is None assert b.remote_location == wd.name + "/" - assert b.mode == 'wb' + assert b.mode == "wb" with b as w: - w.write(_pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]})) + w.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) df = _pd.read_parquet(_os.path.join(wd.name, "000000")) - assert list(df['a']) == [1, 2, 3, 4] - assert list(df['b']) == [5, 6, 7, 8] + assert list(df["a"]) == [1, 2, 3, 4] + assert list(df["b"]) == [5, 6, 7, 8] def test_generic_schema_read(): with _test_utils.LocalTestFileSystem(): a = _schema_impl.Schema.create_at_any_location( - schema_type=_schema_impl.SchemaType([('a', _primitives.Integer), ('b', _primitives.Integer)]) + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]) ) with a as writer: - writer.write(_pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]})) + writer.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) - b = _schema_impl.Schema.fetch( - a.remote_prefix, - schema_type=_schema_impl.SchemaType([])) + b = _schema_impl.Schema.fetch(a.remote_prefix, schema_type=_schema_impl.SchemaType([])) with b as reader: df = reader.read() - assert df.columns.values.tolist() == ['a', 'b'] - assert df['a'].tolist() == [1, 2, 3, 4] - assert df['b'].tolist() == [5, 6, 7, 8] + assert df.columns.values.tolist() == ["a", "b"] + assert df["a"].tolist() == [1, 2, 3, 4] + assert df["b"].tolist() == [5, 6, 7, 8] def test_extra_schema_read(): with _test_utils.LocalTestFileSystem(): a = _schema_impl.Schema.create_at_any_location( - schema_type=_schema_impl.SchemaType([('a', _primitives.Integer), ('b', _primitives.Integer)]) + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]) ) with a as writer: - writer.write(_pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]})) + writer.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) b = _schema_impl.Schema.fetch( - a.remote_prefix, - schema_type=_schema_impl.SchemaType([('a', _primitives.Integer)])) + a.remote_prefix, schema_type=_schema_impl.SchemaType([("a", _primitives.Integer)]), + ) with b as reader: df = reader.read(concat=True, truncate_extra_columns=False) - assert df.columns.values.tolist() == ['a', 'b'] - assert df['a'].tolist() == [1, 2, 3, 4] - assert df['b'].tolist() == [5, 6, 7, 8] + assert df.columns.values.tolist() == ["a", "b"] + assert df["a"].tolist() == [1, 2, 3, 4] + assert df["b"].tolist() == [5, 6, 7, 8] with b as reader: df = reader.read(concat=True) - assert df.columns.values.tolist() == ['a'] - assert df['a'].tolist() == [1, 2, 3, 4] + assert df.columns.values.tolist() == ["a"] + assert df["a"].tolist() == [1, 2, 3, 4] def test_normal_schema_read_with_fastparquet(): with _test_utils.LocalTestFileSystem(): a = _schema_impl.Schema.create_at_any_location( - schema_type=_schema_impl.SchemaType([('a', _primitives.Integer), ('b', 
_primitives.Boolean)]) + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Boolean)]) ) with a as writer: - writer.write(_pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [False, True, True, False]})) + writer.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [False, True, True, False]})) import os as _os - original_engine = _os.getenv('PARQUET_ENGINE') - _os.environ['PARQUET_ENGINE'] = 'fastparquet' - b = _schema_impl.Schema.fetch( - a.remote_prefix, - schema_type=_schema_impl.SchemaType([])) + original_engine = _os.getenv("PARQUET_ENGINE") + _os.environ["PARQUET_ENGINE"] = "fastparquet" + + b = _schema_impl.Schema.fetch(a.remote_prefix, schema_type=_schema_impl.SchemaType([])) with b as reader: df = reader.read() - assert df['a'].tolist() == [1, 2, 3, 4] - assert _pd.api.types.is_bool_dtype(df.dtypes['b']) - assert df['b'].tolist() == [False, True, True, False] + assert df["a"].tolist() == [1, 2, 3, 4] + assert _pd.api.types.is_bool_dtype(df.dtypes["b"]) + assert df["b"].tolist() == [False, True, True, False] if original_engine is None: - del _os.environ['PARQUET_ENGINE'] + del _os.environ["PARQUET_ENGINE"] else: - _os.environ['PARQUET_ENGINE'] = original_engine + _os.environ["PARQUET_ENGINE"] = original_engine def test_schema_read_consistency_between_two_engines(): with _test_utils.LocalTestFileSystem(): a = _schema_impl.Schema.create_at_any_location( - schema_type=_schema_impl.SchemaType([('a', _primitives.Integer), ('b', _primitives.Boolean)]) + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Boolean)]) ) with a as writer: - writer.write(_pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [True, True, True, False]})) + writer.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [True, True, True, False]})) import os as _os - original_engine = _os.getenv('PARQUET_ENGINE') - _os.environ['PARQUET_ENGINE'] = 'fastparquet' - b = _schema_impl.Schema.fetch( - a.remote_prefix, - schema_type=_schema_impl.SchemaType([])) + original_engine = _os.getenv("PARQUET_ENGINE") + _os.environ["PARQUET_ENGINE"] = "fastparquet" + + b = _schema_impl.Schema.fetch(a.remote_prefix, schema_type=_schema_impl.SchemaType([])) with b as b_reader: b_df = b_reader.read() - _os.environ['PARQUET_ENGINE'] = 'pyarrow' + _os.environ["PARQUET_ENGINE"] = "pyarrow" - c = _schema_impl.Schema.fetch( - a.remote_prefix, - schema_type=_schema_impl.SchemaType([])) + c = _schema_impl.Schema.fetch(a.remote_prefix, schema_type=_schema_impl.SchemaType([])) with c as c_reader: c_df = c_reader.read() assert b_df.equals(c_df) if original_engine is None: - del _os.environ['PARQUET_ENGINE'] + del _os.environ["PARQUET_ENGINE"] else: - _os.environ['PARQUET_ENGINE'] = original_engine + _os.environ["PARQUET_ENGINE"] = original_engine diff --git a/tests/flytekit/unit/common_tests/types/test_blobs.py b/tests/flytekit/unit/common_tests/types/test_blobs.py index 85986535e3..f5047da966 100644 --- a/tests/flytekit/unit/common_tests/types/test_blobs.py +++ b/tests/flytekit/unit/common_tests/types/test_blobs.py @@ -1,4 +1,5 @@ from __future__ import absolute_import + from flytekit.common.types import blobs from flytekit.common.types.impl import blobs as blob_impl from flytekit.models import literals as _literal_models @@ -38,12 +39,9 @@ def test_blob_promote_from_model(): scalar=_literal_models.Scalar( blob=_literal_models.Blob( _literal_models.BlobMetadata( - _core_types.BlobType( - format="f", - dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE - ) + 
_core_types.BlobType(format="f", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,) ), - "some/path" + "some/path", ) ) ) diff --git a/tests/flytekit/unit/common_tests/types/test_containers.py b/tests/flytekit/unit/common_tests/types/test_containers.py index d44ce06274..cdba800058 100644 --- a/tests/flytekit/unit/common_tests/types/test_containers.py +++ b/tests/flytekit/unit/common_tests/types/test_containers.py @@ -1,11 +1,12 @@ from __future__ import absolute_import import pytest +from six.moves import range as _range from flytekit.common.exceptions import user as _user_exceptions -from flytekit.common.types import primitives, containers -from flytekit.models import types as literal_types, literals -from six.moves import range as _range +from flytekit.common.types import containers, primitives +from flytekit.models import literals +from flytekit.models import types as literal_types def test_list(): @@ -24,14 +25,14 @@ def test_list(): assert list_value.collection.literals[2].scalar.primitive.integer == 3 assert list_value.collection.literals[3].scalar.primitive.integer == 4 - obj2 = list_type.from_string('[1, 2, 3,4]') + obj2 = list_type.from_string("[1, 2, 3,4]") assert obj2 == list_value with pytest.raises(_user_exceptions.FlyteTypeException): - list_type.from_python_std(['a', 'b', 'c', 'd']) + list_type.from_python_std(["a", "b", "c", "d"]) with pytest.raises(_user_exceptions.FlyteTypeException): - list_type.from_python_std([1, 2, 3, 'abc']) + list_type.from_python_std([1, 2, 3, "abc"]) with pytest.raises(_user_exceptions.FlyteTypeException): list_type.from_python_std(1) @@ -43,10 +44,10 @@ def test_list(): list_type.from_string('["fdsa"]') with pytest.raises(_user_exceptions.FlyteTypeException): - list_type.from_string('[1, 2, 3, []]') + list_type.from_string("[1, 2, 3, []]") with pytest.raises(_user_exceptions.FlyteTypeException): - list_type.from_string('\'["not list json"]\'') + list_type.from_string("'[\"not list json\"]'") with pytest.raises(_user_exceptions.FlyteTypeException): list_type.from_string('["unclosed","list"') @@ -66,7 +67,7 @@ def test_string_list(): def test_empty_parsing(): list_type = containers.List(primitives.String) - obj = list_type.from_string('[]') + obj = list_type.from_string("[]") assert len(obj) == 0 # The String primitive type does not allow lists or maps to be converted @@ -103,7 +104,7 @@ def test_nested_list(): assert len(list_value.collection.literals[2].collection.literals) == 0 - obj = list_type.from_string('[[1, 2, 3], [4, 5, 6]]') + obj = list_type.from_string("[[1, 2, 3], [4, 5, 6]]") assert len(obj) == 2 assert len(obj.collection.literals[0]) == 3 @@ -112,44 +113,48 @@ def test_reprs(): list_type = containers.List(primitives.Integer) obj = list_type.from_python_std(list(_range(3))) assert obj.short_string() == "List(len=3, [Integer(0), Integer(1), Integer(2)])" - assert obj.verbose_string() == \ - "List(\n" \ - "\tlen=3,\n" \ - "\t[\n" \ - "\t\tInteger(0),\n" \ - "\t\tInteger(1),\n" \ - "\t\tInteger(2)\n" \ - "\t]\n" \ + assert ( + obj.verbose_string() == "List(\n" + "\tlen=3,\n" + "\t[\n" + "\t\tInteger(0),\n" + "\t\tInteger(1),\n" + "\t\tInteger(2)\n" + "\t]\n" ")" + ) nested_list_type = containers.List(containers.List(primitives.Integer)) nested_obj = nested_list_type.from_python_std([list(_range(3)), list(_range(3))]) - assert nested_obj.short_string() == \ - "List>(len=2, [List(len=3, [Integer(0), Integer(1), Integer(2)]), " \ + assert ( + nested_obj.short_string() + == "List>(len=2, [List(len=3, [Integer(0), Integer(1), 
Integer(2)]), " "List(len=3, [Integer(0), Integer(1), Integer(2)])])" - assert nested_obj.verbose_string() == \ - "List>(\n" \ - "\tlen=2,\n" \ - "\t[\n" \ - "\t\tList(\n" \ - "\t\t\tlen=3,\n" \ - "\t\t\t[\n" \ - "\t\t\t\tInteger(0),\n" \ - "\t\t\t\tInteger(1),\n" \ - "\t\t\t\tInteger(2)\n" \ - "\t\t\t]\n" \ - "\t\t),\n" \ - "\t\tList(\n" \ - "\t\t\tlen=3,\n" \ - "\t\t\t[\n" \ - "\t\t\t\tInteger(0),\n" \ - "\t\t\t\tInteger(1),\n" \ - "\t\t\t\tInteger(2)\n" \ - "\t\t\t]\n" \ - "\t\t)\n" \ - "\t]\n" \ + ) + assert ( + nested_obj.verbose_string() == "List>(\n" + "\tlen=2,\n" + "\t[\n" + "\t\tList(\n" + "\t\t\tlen=3,\n" + "\t\t\t[\n" + "\t\t\t\tInteger(0),\n" + "\t\t\t\tInteger(1),\n" + "\t\t\t\tInteger(2)\n" + "\t\t\t]\n" + "\t\t),\n" + "\t\tList(\n" + "\t\t\tlen=3,\n" + "\t\t\t[\n" + "\t\t\t\tInteger(0),\n" + "\t\t\t\tInteger(1),\n" + "\t\t\t\tInteger(2)\n" + "\t\t\t]\n" + "\t\t)\n" + "\t]\n" ")" + ) def test_model_promotion(): @@ -159,7 +164,7 @@ def test_model_promotion(): literals=[ literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=0))), literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), - literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=2))) + literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=2))), ] ) ) diff --git a/tests/flytekit/unit/common_tests/types/test_helpers.py b/tests/flytekit/unit/common_tests/types/test_helpers.py index c85402643b..875a0bc40c 100644 --- a/tests/flytekit/unit/common_tests/types/test_helpers.py +++ b/tests/flytekit/unit/common_tests/types/test_helpers.py @@ -1,6 +1,9 @@ from __future__ import absolute_import -from flytekit.common.types import helpers as _type_helpers, base_sdk_types as _base_sdk_types -from flytekit.models import literals as _literals, types as _model_types + +from flytekit.common.types import base_sdk_types as _base_sdk_types +from flytekit.common.types import helpers as _type_helpers +from flytekit.models import literals as _literals +from flytekit.models import types as _model_types from flytekit.sdk import types as _sdk_types @@ -30,20 +33,17 @@ def test_infer_sdk_type_from_literal(): def test_get_sdk_value_from_literal(): - o = _type_helpers.get_sdk_value_from_literal( - _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void())) - ) + o = _type_helpers.get_sdk_value_from_literal(_literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void()))) assert o.to_python_std() is None o = _type_helpers.get_sdk_value_from_literal( - _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void())), - sdk_type=_sdk_types.Types.Integer + _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void())), sdk_type=_sdk_types.Types.Integer, ) assert o.to_python_std() is None o = _type_helpers.get_sdk_value_from_literal( _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=1))), - sdk_type=_sdk_types.Types.Integer + sdk_type=_sdk_types.Types.Integer, ) assert o.to_python_std() == 1 diff --git a/tests/flytekit/unit/common_tests/types/test_primitives.py b/tests/flytekit/unit/common_tests/types/test_primitives.py index ea80093b30..9b7d32d898 100644 --- a/tests/flytekit/unit/common_tests/types/test_primitives.py +++ b/tests/flytekit/unit/common_tests/types/test_primitives.py @@ -1,10 +1,13 @@ from __future__ import absolute_import -from flytekit.common.exceptions import user as user_exceptions -from flytekit.common.types import primitives, base_sdk_types -from flytekit.models import types as 
literal_types -from dateutil import tz + import datetime + import pytest +from dateutil import tz + +from flytekit.common.exceptions import user as user_exceptions +from flytekit.common.types import base_sdk_types, primitives +from flytekit.models import types as literal_types def test_integer(): @@ -16,7 +19,14 @@ def test_integer(): assert obj.to_python_std() == 1 assert primitives.Integer.from_flyte_idl(obj.to_flyte_idl()) == obj - for val in [1.0, 'abc', True, False, datetime.datetime.now(), datetime.timedelta(seconds=1)]: + for val in [ + 1.0, + "abc", + True, + False, + datetime.datetime.now(), + datetime.timedelta(seconds=1), + ]: with pytest.raises(user_exceptions.FlyteTypeException): primitives.Integer.from_python_std(val) @@ -26,13 +36,13 @@ def test_integer(): # Test string parsing with pytest.raises(user_exceptions.FlyteTypeException): - primitives.Integer.from_string('books') - obj = primitives.Integer.from_string('299792458') + primitives.Integer.from_string("books") + obj = primitives.Integer.from_string("299792458") assert obj.to_python_std() == 299792458 assert primitives.Integer.from_flyte_idl(obj.to_flyte_idl()) == obj - assert obj.short_string() == 'Integer(299792458)' - assert obj.verbose_string() == 'Integer(299792458)' + assert obj.short_string() == "Integer(299792458)" + assert obj.verbose_string() == "Integer(299792458)" def test_float(): @@ -44,7 +54,14 @@ def test_float(): assert obj.to_python_std() == 1.0 assert primitives.Float.from_flyte_idl(obj.to_flyte_idl()) == obj - for val in [1, 'abc', True, False, datetime.datetime.now(), datetime.timedelta(seconds=1)]: + for val in [ + 1, + "abc", + True, + False, + datetime.datetime.now(), + datetime.timedelta(seconds=1), + ]: with pytest.raises(user_exceptions.FlyteTypeException): primitives.Float.from_python_std(val) @@ -54,13 +71,13 @@ def test_float(): # Test string parsing with pytest.raises(user_exceptions.FlyteTypeException): - primitives.Float.from_string('lightning') - obj = primitives.Float.from_string('2.71828') + primitives.Float.from_string("lightning") + obj = primitives.Float.from_string("2.71828") assert obj.to_python_std() == 2.71828 assert primitives.Float.from_flyte_idl(obj.to_flyte_idl()) == obj - assert obj.short_string() == 'Float(2.71828)' - assert obj.verbose_string() == 'Float(2.71828)' + assert obj.short_string() == "Float(2.71828)" + assert obj.verbose_string() == "Float(2.71828)" def test_boolean(): @@ -72,7 +89,7 @@ def test_boolean(): assert obj.to_python_std() is True assert primitives.Boolean.from_flyte_idl(obj.to_flyte_idl()) == obj - for val in [1, 1.0, 'abc', datetime.datetime.now(), datetime.timedelta(seconds=1)]: + for val in [1, 1.0, "abc", datetime.datetime.now(), datetime.timedelta(seconds=1)]: with pytest.raises(user_exceptions.FlyteTypeException): primitives.Boolean.from_python_std(val) @@ -82,24 +99,24 @@ def test_boolean(): # Test string parsing with pytest.raises(user_exceptions.FlyteTypeException): - primitives.Boolean.from_string('lightning') - obj = primitives.Boolean.from_string('false') + primitives.Boolean.from_string("lightning") + obj = primitives.Boolean.from_string("false") assert not obj.to_python_std() assert primitives.Boolean.from_flyte_idl(obj.to_flyte_idl()) == obj - obj = primitives.Boolean.from_string('False') + obj = primitives.Boolean.from_string("False") assert not obj.to_python_std() - obj = primitives.Boolean.from_string('0') + obj = primitives.Boolean.from_string("0") assert not obj.to_python_std() - obj = primitives.Boolean.from_string('true') + obj 
= primitives.Boolean.from_string("true")
     assert obj.to_python_std()
-    obj = primitives.Boolean.from_string('True')
+    obj = primitives.Boolean.from_string("True")
     assert obj.to_python_std()
-    obj = primitives.Boolean.from_string('1')
+    obj = primitives.Boolean.from_string("1")
     assert obj.to_python_std()
     assert primitives.Boolean.from_flyte_idl(obj.to_flyte_idl()) == obj
 
-    assert obj.short_string() == 'Boolean(True)'
-    assert obj.verbose_string() == 'Boolean(True)'
+    assert obj.short_string() == "Boolean(True)"
+    assert obj.verbose_string() == "Boolean(True)"
 
 
 def test_string():
@@ -107,11 +124,18 @@ def test_string():
     assert primitives.String.to_flyte_literal_type().simple == literal_types.SimpleType.STRING
 
     # Test value behavior
-    obj = primitives.String.from_python_std('abc')
-    assert obj.to_python_std() == 'abc'
+    obj = primitives.String.from_python_std("abc")
+    assert obj.to_python_std() == "abc"
     assert primitives.String.from_flyte_idl(obj.to_flyte_idl()) == obj
 
-    for val in [1, 1.0, True, False, datetime.datetime.now(), datetime.timedelta(seconds=1)]:
+    for val in [
+        1,
+        1.0,
+        True,
+        False,
+        datetime.datetime.now(),
+        datetime.timedelta(seconds=1),
+    ]:
         with pytest.raises(user_exceptions.FlyteTypeException):
             primitives.String.from_python_std(val)
 
@@ -120,7 +144,7 @@ def test_string():
     assert primitives.String.from_flyte_idl(obj.to_flyte_idl()) == obj
 
     # Test string parsing
-    my_string = 'this is a string'
+    my_string = "this is a string"
     obj = primitives.String.from_string(my_string)
     assert obj.to_python_std() == my_string
     assert primitives.String.from_flyte_idl(obj.to_flyte_idl()) == obj
@@ -162,7 +186,7 @@ def test_datetime():
     with pytest.raises(user_exceptions.FlyteValueException):
         primitives.Datetime.from_python_std(datetime.datetime.now())
 
-    for val in [1, 1.0, 'abc', True, False, datetime.timedelta(seconds=1)]:
+    for val in [1, 1.0, "abc", True, False, datetime.timedelta(seconds=1)]:
         with pytest.raises(user_exceptions.FlyteTypeException):
             primitives.Datetime.from_python_std(val)
 
@@ -172,8 +196,8 @@
 
     # Test string parsing
     with pytest.raises(user_exceptions.FlyteTypeException):
-        primitives.Datetime.from_string('not a real date')
-    obj = primitives.Datetime.from_string('2018-05-15 4:32pm UTC')
+        primitives.Datetime.from_string("not a real date")
+    obj = primitives.Datetime.from_string("2018-05-15 4:32pm UTC")
     test_dt = datetime.datetime(2018, 5, 15, 16, 32, 0, 0, UTC())
     assert obj.short_string() == "Datetime(2018-05-15 16:32:00+00:00)"
     assert obj.verbose_string() == "Datetime(2018-05-15 16:32:00+00:00)"
@@ -190,7 +214,7 @@ def test_timedelta():
     assert obj.to_python_std() == datetime.timedelta(seconds=1)
     assert primitives.Timedelta.from_flyte_idl(obj.to_flyte_idl()) == obj
 
-    for val in [1.0, 'abc', True, False, datetime.datetime.now()]:
+    for val in [1.0, "abc", True, False, datetime.datetime.now()]:
         with pytest.raises(user_exceptions.FlyteTypeException):
             primitives.Timedelta.from_python_std(val)
 
@@ -200,8 +224,8 @@
 
     # Test string parsing
     with pytest.raises(user_exceptions.FlyteTypeException):
-        primitives.Timedelta.from_string('not a real duration')
-    obj = primitives.Timedelta.from_string('15 hours, 1.1 second')
+        primitives.Timedelta.from_string("not a real duration")
+    obj = primitives.Timedelta.from_string("15 hours, 1.1 second")
     test_d = datetime.timedelta(hours=15, seconds=1, milliseconds=100)
     assert obj.short_string() == "Timedelta(15:00:01.100000)"
     assert obj.verbose_string() == "Timedelta(15:00:01.100000)"
@@ -215,7 +239,16 @@ def test_void():
test_void(): base_sdk_types.Void.to_flyte_literal_type() # Test value behavior - for val in [1, 1.0, 'abc', True, False, datetime.datetime.now(), datetime.timedelta(seconds=1), None]: + for val in [ + 1, + 1.0, + "abc", + True, + False, + datetime.datetime.now(), + datetime.timedelta(seconds=1), + None, + ]: assert base_sdk_types.Void.from_python_std(val).to_python_std() is None obj = base_sdk_types.Void() @@ -230,12 +263,19 @@ def test_generic(): assert primitives.Generic.to_flyte_literal_type().simple == literal_types.SimpleType.STRUCT # Test value behavior - d = {'a': [1, 2, 3], 'b': 'abc', 'c': 1, 'd': {'a': 1}} + d = {"a": [1, 2, 3], "b": "abc", "c": 1, "d": {"a": 1}} obj = primitives.Generic.from_python_std(d) assert obj.to_python_std() == d assert primitives.Generic.from_flyte_idl(obj.to_flyte_idl()) == obj - for val in [1.0, 'abc', True, False, datetime.datetime.now(), datetime.timedelta(seconds=1)]: + for val in [ + 1.0, + "abc", + True, + False, + datetime.datetime.now(), + datetime.timedelta(seconds=1), + ]: with pytest.raises(user_exceptions.FlyteTypeException): primitives.Generic.from_python_std(val) @@ -245,7 +285,7 @@ def test_generic(): # Test string parsing with pytest.raises(user_exceptions.FlyteValueException): - primitives.Generic.from_string('1') + primitives.Generic.from_string("1") obj = primitives.Generic.from_string('{"a": 1.0}') assert obj.to_python_std() == {"a": 1.0} assert primitives.Generic.from_flyte_idl(obj.to_flyte_idl()) == obj diff --git a/tests/flytekit/unit/common_tests/types/test_proto.py b/tests/flytekit/unit/common_tests/types/test_proto.py index 67ba202fa7..9fab12fca6 100644 --- a/tests/flytekit/unit/common_tests/types/test_proto.py +++ b/tests/flytekit/unit/common_tests/types/test_proto.py @@ -1,10 +1,13 @@ from __future__ import absolute_import + +import base64 as _base64 + +import pytest as _pytest +from flyteidl.core import errors_pb2 as _errors_pb2 + from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import proto as _proto -from flyteidl.core import errors_pb2 as _errors_pb2 from flytekit.models import types as _type_models -import base64 as _base64 -import pytest as _pytest def test_wrong_type(): @@ -16,8 +19,10 @@ def test_proto_to_literal_type(): proto_type = _proto.create_protobuf(_errors_pb2.ContainerError) assert proto_type.to_flyte_literal_type().simple == _type_models.SimpleType.BINARY assert len(proto_type.to_flyte_literal_type().metadata) == 1 - assert proto_type.to_flyte_literal_type().metadata[_proto.Protobuf.PB_FIELD_KEY] == \ - "flyteidl.core.errors_pb2.ContainerError" + assert ( + proto_type.to_flyte_literal_type().metadata[_proto.Protobuf.PB_FIELD_KEY] + == "flyteidl.core.errors_pb2.ContainerError" + ) def test_proto(): diff --git a/tests/flytekit/unit/common_tests/types/test_schema.py b/tests/flytekit/unit/common_tests/types/test_schema.py index b2b545325f..512e814fe6 100644 --- a/tests/flytekit/unit/common_tests/types/test_schema.py +++ b/tests/flytekit/unit/common_tests/types/test_schema.py @@ -1,16 +1,16 @@ from __future__ import absolute_import -from flytekit.common.types import schema, primitives + +from flytekit.common.types import primitives, schema from flytekit.common.types.impl import schema as schema_impl from flytekit.sdk import test_utils - _ALL_COLUMN_TYPES = [ - ('a', primitives.Integer), - ('b', primitives.String), - ('c', primitives.Float), - ('d', primitives.Datetime), - ('e', primitives.Timedelta), - ('f', primitives.Boolean), + ("a", primitives.Integer), + ("b", 
primitives.String), + ("c", primitives.Float), + ("d", primitives.Datetime), + ("e", primitives.Timedelta), + ("f", primitives.Boolean), ] @@ -59,9 +59,9 @@ def test_casting(): class MyDateTime(primitives.Datetime): ... - with test_utils.LocalTestFileSystem() as t: - test_columns_1 = [('altered', MyDateTime)] - test_columns_2 = [('altered', primitives.Datetime)] + with test_utils.LocalTestFileSystem(): + test_columns_1 = [("altered", MyDateTime)] + test_columns_2 = [("altered", primitives.Datetime)] instantiator_1 = schema.schema_instantiator(test_columns_1) a = instantiator_1() diff --git a/tests/flytekit/unit/configuration/conftest.py b/tests/flytekit/unit/configuration/conftest.py index ba1fe76457..9feb0c8738 100644 --- a/tests/flytekit/unit/configuration/conftest.py +++ b/tests/flytekit/unit/configuration/conftest.py @@ -1,8 +1,11 @@ from __future__ import absolute_import -from flytekit.configuration import set_flyte_config_file as _set_config -import pytest as _pytest + import os as _os +import pytest as _pytest + +from flytekit.configuration import set_flyte_config_file as _set_config + @_pytest.fixture(scope="function", autouse=True) def clear_configs(): diff --git a/tests/flytekit/unit/configuration/test_common.py b/tests/flytekit/unit/configuration/test_common.py index 273c3cba1e..8210753e92 100644 --- a/tests/flytekit/unit/configuration/test_common.py +++ b/tests/flytekit/unit/configuration/test_common.py @@ -1,34 +1,36 @@ from __future__ import absolute_import -from flytekit.configuration import set_flyte_config_file, common + import os + import pytest +from flytekit.configuration import common, set_flyte_config_file + def test_file_loader_bad(): - set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'configs/bad.config')) + set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/bad.config")) with pytest.raises(Exception): - common.CONFIGURATION_SINGLETON.get_string('a', 'b') + common.CONFIGURATION_SINGLETON.get_string("a", "b") def test_file_loader_good(): - set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'configs/good.config')) - assert common.CONFIGURATION_SINGLETON.get_string('sdk', 'workflow_packages') == \ - 'this.module,that.module' - assert common.CONFIGURATION_SINGLETON.get_string('auth', 'assumable_iam_role') == 'some_role' + set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/good.config")) + assert common.CONFIGURATION_SINGLETON.get_string("sdk", "workflow_packages") == "this.module,that.module" + assert common.CONFIGURATION_SINGLETON.get_string("auth", "assumable_iam_role") == "some_role" def test_env_var_precedence_string(): - set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'configs/good.config')) + set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/good.config")) - assert common.FlyteIntegerConfigurationEntry('madeup', 'int_value').get() == 3 - assert common.FlyteStringConfigurationEntry('madeup', 'string_value').get() == 'abc' + assert common.FlyteIntegerConfigurationEntry("madeup", "int_value").get() == 3 + assert common.FlyteStringConfigurationEntry("madeup", "string_value").get() == "abc" old_environ = dict(os.environ) try: - os.environ['FLYTE_MADEUP_INT_VALUE'] = '10' - os.environ["FLYTE_MADEUP_STRING_VALUE"] = 'overridden' - assert common.FlyteIntegerConfigurationEntry('madeup', 'int_value').get() == 10 - assert common.FlyteStringConfigurationEntry('madeup', 
'string_value').get() == 'overridden' + os.environ["FLYTE_MADEUP_INT_VALUE"] = "10" + os.environ["FLYTE_MADEUP_STRING_VALUE"] = "overridden" + assert common.FlyteIntegerConfigurationEntry("madeup", "int_value").get() == 10 + assert common.FlyteStringConfigurationEntry("madeup", "string_value").get() == "overridden" finally: os.environ.clear() os.environ.update(old_environ) diff --git a/tests/flytekit/unit/configuration/test_internal.py b/tests/flytekit/unit/configuration/test_internal.py index 891b049a40..3d28e47abc 100644 --- a/tests/flytekit/unit/configuration/test_internal.py +++ b/tests/flytekit/unit/configuration/test_internal.py @@ -6,10 +6,10 @@ def test_parsing(): - str = 'somedocker.com/myimage:someversion123' + str = "somedocker.com/myimage:someversion123" version = look_up_version_from_image_tag(str) - assert version == 'someversion123' + assert version == "someversion123" - str = 'ffjdskl/jfkljkdfls' + str = "ffjdskl/jfkljkdfls" with pytest.raises(Exception): look_up_version_from_image_tag(str) diff --git a/tests/flytekit/unit/configuration/test_resources.py b/tests/flytekit/unit/configuration/test_resources.py index 1d117f2294..c50a82cac4 100644 --- a/tests/flytekit/unit/configuration/test_resources.py +++ b/tests/flytekit/unit/configuration/test_resources.py @@ -2,7 +2,7 @@ import os -from flytekit.configuration import set_flyte_config_file, resources +from flytekit.configuration import resources, set_flyte_config_file def test_resource_hints_default(): @@ -17,7 +17,7 @@ def test_resource_hints_default(): def test_resource_hints(): - set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'configs/good.config')) + set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/good.config")) assert resources.DEFAULT_CPU_REQUEST.get() == "500m" assert resources.DEFAULT_CPU_LIMIT.get() == "501m" assert resources.DEFAULT_MEMORY_REQUEST.get() == "500Gi" diff --git a/tests/flytekit/unit/configuration/test_temporary_configuration.py b/tests/flytekit/unit/configuration/test_temporary_configuration.py index c569c2bc23..4bf7e8ae95 100644 --- a/tests/flytekit/unit/configuration/test_temporary_configuration.py +++ b/tests/flytekit/unit/configuration/test_temporary_configuration.py @@ -1,37 +1,35 @@ from __future__ import absolute_import -from flytekit.configuration import set_flyte_config_file as _set_flyte_config_file, \ - common as _common, \ - TemporaryConfiguration as _TemporaryConfiguration + import os as _os +from flytekit.configuration import TemporaryConfiguration as _TemporaryConfiguration +from flytekit.configuration import common as _common +from flytekit.configuration import set_flyte_config_file as _set_flyte_config_file + def test_configuration_file(): - with _TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), 'configs/good.config')): - assert _common.CONFIGURATION_SINGLETON.get_string('sdk', 'workflow_packages') == \ - 'this.module,that.module' - assert _common.CONFIGURATION_SINGLETON.get_string('sdk', 'workflow_packages') is None + with _TemporaryConfiguration(_os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "configs/good.config")): + assert _common.CONFIGURATION_SINGLETON.get_string("sdk", "workflow_packages") == "this.module,that.module" + assert _common.CONFIGURATION_SINGLETON.get_string("sdk", "workflow_packages") is None def test_internal_overrides(): with _TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), 'configs/good.config'), - 
{'foo': 'bar'}): - assert _os.environ.get('FLYTE_INTERNAL_FOO') == 'bar' - assert _os.environ.get('FLYTE_INTERNAL_FOO') is None + _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "configs/good.config"), {"foo": "bar"}, + ): + assert _os.environ.get("FLYTE_INTERNAL_FOO") == "bar" + assert _os.environ.get("FLYTE_INTERNAL_FOO") is None def test_no_configuration_file(): - _set_flyte_config_file(_os.path.join(_os.path.dirname(_os.path.realpath(__file__)), 'configs/good.config')) + _set_flyte_config_file(_os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "configs/good.config")) with _TemporaryConfiguration(None): - assert _common.CONFIGURATION_SINGLETON.get_string('sdk', 'workflow_packages') is None - assert _common.CONFIGURATION_SINGLETON.get_string('sdk', 'workflow_packages') == \ - 'this.module,that.module' + assert _common.CONFIGURATION_SINGLETON.get_string("sdk", "workflow_packages") is None + assert _common.CONFIGURATION_SINGLETON.get_string("sdk", "workflow_packages") == "this.module,that.module" def test_nonexist_configuration_file(): - _set_flyte_config_file(_os.path.join(_os.path.dirname(_os.path.realpath(__file__)), 'configs/good.config')) - with _TemporaryConfiguration('/foo/bar'): - assert _common.CONFIGURATION_SINGLETON.get_string('sdk', 'workflow_packages') is None - assert _common.CONFIGURATION_SINGLETON.get_string('sdk', 'workflow_packages') == \ - 'this.module,that.module' + _set_flyte_config_file(_os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "configs/good.config")) + with _TemporaryConfiguration("/foo/bar"): + assert _common.CONFIGURATION_SINGLETON.get_string("sdk", "workflow_packages") is None + assert _common.CONFIGURATION_SINGLETON.get_string("sdk", "workflow_packages") == "this.module,that.module" diff --git a/tests/flytekit/unit/configuration/test_waterfall.py b/tests/flytekit/unit/configuration/test_waterfall.py index 0de9317e08..07460e434d 100644 --- a/tests/flytekit/unit/configuration/test_waterfall.py +++ b/tests/flytekit/unit/configuration/test_waterfall.py @@ -1,47 +1,46 @@ from __future__ import absolute_import -from flytekit.configuration import set_flyte_config_file as _set_flyte_config_file, \ - common as _common, \ - TemporaryConfiguration as _TemporaryConfiguration -from flytekit.common.utils import AutoDeletingTempDir as _AutoDeletingTempDir import os as _os +from flytekit.common.utils import AutoDeletingTempDir as _AutoDeletingTempDir +from flytekit.configuration import common as _common + def test_lookup_waterfall_raw_env_var(): - x = _common.FlyteStringConfigurationEntry('test', 'setting', default=None) + x = _common.FlyteStringConfigurationEntry("test", "setting", default=None) - if 'FLYTE_TEST_SETTING' in _os.environ: - del _os.environ['FLYTE_TEST_SETTING'] + if "FLYTE_TEST_SETTING" in _os.environ: + del _os.environ["FLYTE_TEST_SETTING"] assert x.get() is None - _os.environ['FLYTE_TEST_SETTING'] = 'lorem' - assert x.get() == 'lorem' + _os.environ["FLYTE_TEST_SETTING"] = "lorem" + assert x.get() == "lorem" def test_lookup_waterfall_referenced_env_var(): - x = _common.FlyteStringConfigurationEntry('test', 'setting', default=None) + x = _common.FlyteStringConfigurationEntry("test", "setting", default=None) - if 'FLYTE_TEST_SETTING' in _os.environ: - del _os.environ['FLYTE_TEST_SETTING'] + if "FLYTE_TEST_SETTING" in _os.environ: + del _os.environ["FLYTE_TEST_SETTING"] assert x.get() is None - if 'TEMP_PLACEHOLDER' in _os.environ: - del _os.environ['TEMP_PLACEHOLDER'] - _os.environ['TEMP_PLACEHOLDER'] = 'lorem' - 
_os.environ['FLYTE_TEST_SETTING_FROM_ENV_VAR'] = 'TEMP_PLACEHOLDER' - assert x.get() == 'lorem' + if "TEMP_PLACEHOLDER" in _os.environ: + del _os.environ["TEMP_PLACEHOLDER"] + _os.environ["TEMP_PLACEHOLDER"] = "lorem" + _os.environ["FLYTE_TEST_SETTING_FROM_ENV_VAR"] = "TEMP_PLACEHOLDER" + assert x.get() == "lorem" def test_lookup_waterfall_referenced_file(): - x = _common.FlyteStringConfigurationEntry('test', 'setting', default=None) + x = _common.FlyteStringConfigurationEntry("test", "setting", default=None) - if 'FLYTE_TEST_SETTING' in _os.environ: - del _os.environ['FLYTE_TEST_SETTING'] + if "FLYTE_TEST_SETTING" in _os.environ: + del _os.environ["FLYTE_TEST_SETTING"] assert x.get() is None with _AutoDeletingTempDir("config_testing") as tmp_dir: - with open(tmp_dir.get_named_tempfile('name'), 'w') as fh: - fh.write('secret_password') + with open(tmp_dir.get_named_tempfile("name"), "w") as fh: + fh.write("secret_password") - _os.environ['FLYTE_TEST_SETTING_FROM_FILE'] = tmp_dir.get_named_tempfile('name') - assert x.get() == 'secret_password' + _os.environ["FLYTE_TEST_SETTING_FROM_FILE"] = tmp_dir.get_named_tempfile("name") + assert x.get() == "secret_password" diff --git a/tests/flytekit/unit/contrib/sensors/test_impl.py b/tests/flytekit/unit/contrib/sensors/test_impl.py index 46f1a4e248..f84b4287b3 100644 --- a/tests/flytekit/unit/contrib/sensors/test_impl.py +++ b/tests/flytekit/unit/contrib/sensors/test_impl.py @@ -1,29 +1,22 @@ from __future__ import absolute_import import mock - from hmsclient import HMSClient from hmsclient.genthrift.hive_metastore import ttypes as _ttypes -from flytekit.contrib.sensors.impl import HiveFilteredPartitionSensor -from flytekit.contrib.sensors.impl import HiveNamedPartitionSensor -from flytekit.contrib.sensors.impl import HiveTableSensor +from flytekit.contrib.sensors.impl import HiveFilteredPartitionSensor, HiveNamedPartitionSensor, HiveTableSensor def test_HiveTableSensor(): - hive_table_sensor = HiveTableSensor( - table_name='mocked_table', - host='localhost', - port=1234, - ) - assert hive_table_sensor._schema == 'default' - with mock.patch.object(HMSClient, 'open'): - with mock.patch.object(HMSClient, 'get_table'): + hive_table_sensor = HiveTableSensor(table_name="mocked_table", host="localhost", port=1234) + assert hive_table_sensor._schema == "default" + with mock.patch.object(HMSClient, "open"): + with mock.patch.object(HMSClient, "get_table"): success, interval = hive_table_sensor._do_poll() assert success assert interval is None - with mock.patch.object(HMSClient, 'get_table', side_effect=_ttypes.NoSuchObjectException()): + with mock.patch.object(HMSClient, "get_table", side_effect=_ttypes.NoSuchObjectException()): success, interval = hive_table_sensor._do_poll() assert not success assert interval is None @@ -31,22 +24,18 @@ def test_HiveTableSensor(): def test_HiveNamedPartitionSensor(): hive_named_partition_sensor = HiveNamedPartitionSensor( - table_name='mocked_table', - partition_names=[ - 'ds=2019-10-10', - 'ds=2019-10-11', - ], - host='localhost', - port=1234, + table_name="mocked_table", partition_names=["ds=2019-10-10", "ds=2019-10-11"], host="localhost", port=1234 ) - assert hive_named_partition_sensor._schema == 'default' - with mock.patch.object(HMSClient, 'open'): - with mock.patch.object(HMSClient, 'get_partition_by_name'): + assert hive_named_partition_sensor._schema == "default" + with mock.patch.object(HMSClient, "open"): + with mock.patch.object(HMSClient, "get_partition_by_name"): success, interval = 
hive_named_partition_sensor._do_poll() assert success assert interval is None - - with mock.patch.object(HMSClient, 'get_partition_by_name', side_effect=_ttypes.NoSuchObjectException()): + + with mock.patch.object( + HMSClient, "get_partition_by_name", side_effect=_ttypes.NoSuchObjectException(), + ): success, interval = hive_named_partition_sensor._do_poll() assert not success assert interval is None @@ -54,19 +43,16 @@ def test_HiveNamedPartitionSensor(): def test_HiveFilteredPartitionSensor(): hive_filtered_partition_sensor = HiveFilteredPartitionSensor( - table_name='mocked_table', - partition_filter="ds = '2019-10-10' AND region = 'NYC'", - host='localhost', - port=1234, + table_name="mocked_table", partition_filter="ds = '2019-10-10' AND region = 'NYC'", host="localhost", port=1234 ) - assert hive_filtered_partition_sensor._schema == 'default' - with mock.patch.object(HMSClient, 'open'): - with mock.patch.object(HMSClient, 'get_partitions_by_filter', return_value=['any']): + assert hive_filtered_partition_sensor._schema == "default" + with mock.patch.object(HMSClient, "open"): + with mock.patch.object(HMSClient, "get_partitions_by_filter", return_value=["any"]): success, interval = hive_filtered_partition_sensor._do_poll() assert success assert interval is None - - with mock.patch.object(HMSClient, 'get_partitions_by_filter', return_value=[]): + + with mock.patch.object(HMSClient, "get_partitions_by_filter", return_value=[]): success, interval = hive_filtered_partition_sensor._do_poll() assert not success assert interval is None diff --git a/tests/flytekit/unit/contrib/sensors/test_task.py b/tests/flytekit/unit/contrib/sensors/test_task.py index ca8d3c3797..0956b2aada 100644 --- a/tests/flytekit/unit/contrib/sensors/test_task.py +++ b/tests/flytekit/unit/contrib/sensors/test_task.py @@ -1,10 +1,8 @@ from flytekit.contrib.sensors.base_sensor import Sensor as _Sensor - from flytekit.contrib.sensors.task import sensor_task class MyMockSensor(_Sensor): - def __init__(self, **kwargs): super(MyMockSensor, self).__init__(**kwargs) diff --git a/tests/flytekit/unit/engines/flyte/test_engine.py b/tests/flytekit/unit/engines/flyte/test_engine.py index 3d56636eb1..e40285e82b 100644 --- a/tests/flytekit/unit/engines/flyte/test_engine.py +++ b/tests/flytekit/unit/engines/flyte/test_engine.py @@ -1,44 +1,41 @@ from __future__ import absolute_import import os + import pytest from flyteidl.core import errors_pb2 -from mock import MagicMock, patch, PropertyMock +from mock import MagicMock, PropertyMock, patch from flytekit.common import constants, utils from flytekit.common.exceptions import scopes from flytekit.configuration import TemporaryConfiguration from flytekit.engines.flyte import engine -from flytekit.models import literals, execution as _execution_models, common as _common_models, launch_plan as \ - _launch_plan_models, task as _task_models +from flytekit.models import common as _common_models +from flytekit.models import execution as _execution_models +from flytekit.models import launch_plan as _launch_plan_models +from flytekit.models import literals +from flytekit.models import task as _task_models from flytekit.models.admin import common as _common from flytekit.models.core import errors, identifier from flytekit.sdk import test_utils - _INPUT_MAP = literals.LiteralMap( - { - 'a': literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))) - } + {"a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))} ) _OUTPUT_MAP = 
literals.LiteralMap( - { - 'b': literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=2))) - } + {"b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=2)))} ) @pytest.fixture(scope="function", autouse=True) def temp_config(): with TemporaryConfiguration( - os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../../common/configs/local.config'), - internal_overrides={ - 'image': 'myflyteimage:{}'.format( - os.environ.get('IMAGE_VERSION', 'sha') - ), - 'project': 'myflyteproject', - 'domain': 'development' - } + os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../common/configs/local.config",), + internal_overrides={ + "image": "myflyteimage:{}".format(os.environ.get("IMAGE_VERSION", "sha")), + "project": "myflyteproject", + "domain": "development", + }, ): yield @@ -52,7 +49,7 @@ def execution_data_locations(): utils.write_proto_to_file(_OUTPUT_MAP.to_flyte_idl(), output_filename) yield ( _common_models.UrlBlob(input_filename, 100), - _common_models.UrlBlob(output_filename, 100) + _common_models.UrlBlob(output_filename, 100), ) @@ -72,10 +69,10 @@ def test_task_system_failure(): m.execute = _raise_system_exception with utils.AutoDeletingTempDir("test") as tmp: - engine.FlyteTask(m).execute(None, {'output_prefix': tmp.name}) + engine.FlyteTask(m).execute(None, {"output_prefix": tmp.name}) doc = errors.ErrorDocument.from_flyte_idl( - utils.load_proto_from_file(errors_pb2.ErrorDocument, os.path.join(tmp.name, constants.ERROR_FILE_NAME)) + utils.load_proto_from_file(errors_pb2.ErrorDocument, os.path.join(tmp.name, constants.ERROR_FILE_NAME),) ) assert doc.error.code == "SYSTEM:Unknown" assert doc.error.kind == errors.ContainerError.Kind.RECOVERABLE @@ -88,145 +85,93 @@ def test_task_user_failure(): m.execute = _raise_user_exception with utils.AutoDeletingTempDir("test") as tmp: - engine.FlyteTask(m).execute(None, {'output_prefix': tmp.name}) + engine.FlyteTask(m).execute(None, {"output_prefix": tmp.name}) doc = errors.ErrorDocument.from_flyte_idl( - utils.load_proto_from_file(errors_pb2.ErrorDocument, os.path.join(tmp.name, constants.ERROR_FILE_NAME)) + utils.load_proto_from_file(errors_pb2.ErrorDocument, os.path.join(tmp.name, constants.ERROR_FILE_NAME),) ) assert doc.error.code == "USER:Unknown" assert doc.error.kind == errors.ContainerError.Kind.NON_RECOVERABLE assert "userUSERuser" in doc.error.message -@patch.object(engine._FlyteClientManager, '_CLIENT', new_callable=PropertyMock) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_execution_notification_overrides(mock_client_factory): mock_client = MagicMock() - mock_client.create_execution = MagicMock(return_value=identifier.WorkflowExecutionIdentifier('xp', 'xd', 'xn')) + mock_client.create_execution = MagicMock(return_value=identifier.WorkflowExecutionIdentifier("xp", "xd", "xn")) mock_client_factory.return_value = mock_client m = MagicMock() type(m).id = PropertyMock( - return_value=identifier.Identifier( - identifier.ResourceType.LAUNCH_PLAN, - "project", - "domain", - "name", - "version" - ) + return_value=identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version") ) - engine.FlyteLaunchPlan(m).launch( - 'xp', 'xd', 'xn', literals.LiteralMap({}), notification_overrides=[] - ) + engine.FlyteLaunchPlan(m).launch("xp", "xd", "xn", literals.LiteralMap({}), notification_overrides=[]) mock_client.create_execution.assert_called_once_with( - 'xp', - 'xd', - 'xn', + "xp", + "xd", + "xn", 
_execution_models.ExecutionSpec( - identifier.Identifier( - identifier.ResourceType.LAUNCH_PLAN, - "project", - "domain", - "name", - "version" - ), - _execution_models.ExecutionMetadata( - _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, - 'sdk', - 0 - ), + identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version",), + _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), disable_all=True, ), literals.LiteralMap({}), ) -@patch.object(engine._FlyteClientManager, '_CLIENT', new_callable=PropertyMock) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_execution_notification_soft_overrides(mock_client_factory): mock_client = MagicMock() - mock_client.create_execution = MagicMock(return_value=identifier.WorkflowExecutionIdentifier('xp', 'xd', 'xn')) + mock_client.create_execution = MagicMock(return_value=identifier.WorkflowExecutionIdentifier("xp", "xd", "xn")) mock_client_factory.return_value = mock_client m = MagicMock() type(m).id = PropertyMock( - return_value=identifier.Identifier( - identifier.ResourceType.LAUNCH_PLAN, - "project", - "domain", - "name", - "version" - ) + return_value=identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version") ) notification = _common_models.Notification([0, 1, 2], email=_common_models.EmailNotification(["me@place.com"])) - engine.FlyteLaunchPlan(m).launch( - 'xp', 'xd', 'xn', literals.LiteralMap({}), notification_overrides=[notification] - ) + engine.FlyteLaunchPlan(m).launch("xp", "xd", "xn", literals.LiteralMap({}), notification_overrides=[notification]) mock_client.create_execution.assert_called_once_with( - 'xp', - 'xd', - 'xn', + "xp", + "xd", + "xn", _execution_models.ExecutionSpec( - identifier.Identifier( - identifier.ResourceType.LAUNCH_PLAN, - "project", - "domain", - "name", - "version" - ), - _execution_models.ExecutionMetadata( - _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, - 'sdk', - 0 - ), + identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version",), + _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), notifications=_execution_models.NotificationList([notification]), ), literals.LiteralMap({}), ) -@patch.object(engine._FlyteClientManager, '_CLIENT', new_callable=PropertyMock) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_execution_label_overrides(mock_client_factory): mock_client = MagicMock() - mock_client.create_execution = MagicMock(return_value=identifier.WorkflowExecutionIdentifier('xp', 'xd', 'xn')) + mock_client.create_execution = MagicMock(return_value=identifier.WorkflowExecutionIdentifier("xp", "xd", "xn")) mock_client_factory.return_value = mock_client m = MagicMock() type(m).id = PropertyMock( - return_value=identifier.Identifier( - identifier.ResourceType.LAUNCH_PLAN, - "project", - "domain", - "name", - "version" - ) + return_value=identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version") ) labels = _common_models.Labels({"my": "label"}) engine.FlyteLaunchPlan(m).execute( - 'xp', 'xd', 'xn', literals.LiteralMap({}), notification_overrides=[], label_overrides=labels + "xp", "xd", "xn", literals.LiteralMap({}), notification_overrides=[], label_overrides=labels, ) mock_client.create_execution.assert_called_once_with( - 'xp', - 'xd', - 'xn', + "xp", + "xd", + "xn", 
_execution_models.ExecutionSpec( - identifier.Identifier( - identifier.ResourceType.LAUNCH_PLAN, - "project", - "domain", - "name", - "version" - ), - _execution_models.ExecutionMetadata( - _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, - 'sdk', - 0 - ), + identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version",), + _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), disable_all=True, labels=labels, ), @@ -234,45 +179,29 @@ def test_execution_label_overrides(mock_client_factory): ) -@patch.object(engine._FlyteClientManager, '_CLIENT', new_callable=PropertyMock) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_execution_annotation_overrides(mock_client_factory): mock_client = MagicMock() - mock_client.create_execution = MagicMock(return_value=identifier.WorkflowExecutionIdentifier('xp', 'xd', 'xn')) + mock_client.create_execution = MagicMock(return_value=identifier.WorkflowExecutionIdentifier("xp", "xd", "xn")) mock_client_factory.return_value = mock_client m = MagicMock() type(m).id = PropertyMock( - return_value=identifier.Identifier( - identifier.ResourceType.LAUNCH_PLAN, - "project", - "domain", - "name", - "version" - ) + return_value=identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version") ) annotations = _common_models.Annotations({"my": "annotation"}) engine.FlyteLaunchPlan(m).launch( - 'xp', 'xd', 'xn', literals.LiteralMap({}), notification_overrides=[], annotation_overrides=annotations + "xp", "xd", "xn", literals.LiteralMap({}), notification_overrides=[], annotation_overrides=annotations, ) mock_client.create_execution.assert_called_once_with( - 'xp', - 'xd', - 'xn', + "xp", + "xd", + "xn", _execution_models.ExecutionSpec( - identifier.Identifier( - identifier.ResourceType.LAUNCH_PLAN, - "project", - "domain", - "name", - "version" - ), - _execution_models.ExecutionMetadata( - _execution_models.ExecutionMetadata.ExecutionMode.MANUAL, - 'sdk', - 0 - ), + identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version",), + _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), disable_all=True, annotations=annotations, ), @@ -280,7 +209,7 @@ def test_execution_annotation_overrides(mock_client_factory): ) -@patch.object(engine._FlyteClientManager, '_CLIENT', new_callable=PropertyMock) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_fetch_launch_plan(mock_client_factory): mock_client = MagicMock() mock_client.get_launch_plan = MagicMock( @@ -302,7 +231,7 @@ def test_fetch_launch_plan(mock_client_factory): ) -@patch.object(engine._FlyteClientManager, '_CLIENT', new_callable=PropertyMock) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_fetch_active_launch_plan(mock_client_factory): mock_client = MagicMock() mock_client.get_active_launch_plan = MagicMock( @@ -319,74 +248,57 @@ def test_fetch_active_launch_plan(mock_client_factory): ) assert lp.id == identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "p1", "d1", "n1", "v1") - mock_client.get_active_launch_plan.assert_called_once_with( - _common_models.NamedEntityIdentifier("p", "d", "n") - ) + mock_client.get_active_launch_plan.assert_called_once_with(_common_models.NamedEntityIdentifier("p", "d", "n")) -@patch.object(engine._FlyteClientManager, '_CLIENT', 
new_callable=PropertyMock) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_get_execution_inputs(mock_client_factory, execution_data_locations): mock_client = MagicMock() mock_client.get_execution_data = MagicMock( return_value=_execution_models.WorkflowExecutionGetDataResponse( - execution_data_locations[0], - execution_data_locations[1] + execution_data_locations[0], execution_data_locations[1] ) ) mock_client_factory.return_value = mock_client m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ) - ) + type(m).id = PropertyMock(return_value=identifier.WorkflowExecutionIdentifier("project", "domain", "name",)) inputs = engine.FlyteWorkflowExecution(m).get_inputs() assert len(inputs.literals) == 1 - assert inputs.literals['a'].scalar.primitive.integer == 1 + assert inputs.literals["a"].scalar.primitive.integer == 1 mock_client.get_execution_data.assert_called_once_with( identifier.WorkflowExecutionIdentifier("project", "domain", "name") ) -@patch.object(engine._FlyteClientManager, '_CLIENT', new_callable=PropertyMock) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_get_execution_outputs(mock_client_factory, execution_data_locations): mock_client = MagicMock() mock_client.get_execution_data = MagicMock( return_value=_execution_models.WorkflowExecutionGetDataResponse( - execution_data_locations[0], - execution_data_locations[1] + execution_data_locations[0], execution_data_locations[1] ) ) mock_client_factory.return_value = mock_client m = MagicMock() - type(m).id = PropertyMock( - return_value=identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ) - ) + type(m).id = PropertyMock(return_value=identifier.WorkflowExecutionIdentifier("project", "domain", "name",)) inputs = engine.FlyteWorkflowExecution(m).get_outputs() assert len(inputs.literals) == 1 - assert inputs.literals['b'].scalar.primitive.integer == 2 + assert inputs.literals["b"].scalar.primitive.integer == 2 mock_client.get_execution_data.assert_called_once_with( identifier.WorkflowExecutionIdentifier("project", "domain", "name") ) -@patch.object(engine._FlyteClientManager, '_CLIENT', new_callable=PropertyMock) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_get_node_execution_inputs(mock_client_factory, execution_data_locations): mock_client = MagicMock() mock_client.get_node_execution_data = MagicMock( return_value=_execution_models.NodeExecutionGetDataResponse( - execution_data_locations[0], - execution_data_locations[1] + execution_data_locations[0], execution_data_locations[1] ) ) mock_client_factory.return_value = mock_client @@ -394,37 +306,26 @@ def test_get_node_execution_inputs(mock_client_factory, execution_data_locations m = MagicMock() type(m).id = PropertyMock( return_value=identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ) + "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), ) ) inputs = engine.FlyteNodeExecution(m).get_inputs() assert len(inputs.literals) == 1 - assert inputs.literals['a'].scalar.primitive.integer == 1 + assert inputs.literals["a"].scalar.primitive.integer == 1 mock_client.get_node_execution_data.assert_called_once_with( identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ) + "node-a", 
identifier.WorkflowExecutionIdentifier("project", "domain", "name",), ) ) -@patch.object(engine._FlyteClientManager, '_CLIENT', new_callable=PropertyMock) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_get_node_execution_outputs(mock_client_factory, execution_data_locations): mock_client = MagicMock() mock_client.get_node_execution_data = MagicMock( return_value=_execution_models.NodeExecutionGetDataResponse( - execution_data_locations[0], - execution_data_locations[1] + execution_data_locations[0], execution_data_locations[1] ) ) mock_client_factory.return_value = mock_client @@ -432,37 +333,26 @@ def test_get_node_execution_outputs(mock_client_factory, execution_data_location m = MagicMock() type(m).id = PropertyMock( return_value=identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ) + "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), ) ) inputs = engine.FlyteNodeExecution(m).get_outputs() assert len(inputs.literals) == 1 - assert inputs.literals['b'].scalar.primitive.integer == 2 + assert inputs.literals["b"].scalar.primitive.integer == 2 mock_client.get_node_execution_data.assert_called_once_with( identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ) + "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), ) ) -@patch.object(engine._FlyteClientManager, '_CLIENT', new_callable=PropertyMock) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_get_task_execution_inputs(mock_client_factory, execution_data_locations): mock_client = MagicMock() mock_client.get_task_execution_data = MagicMock( return_value=_execution_models.TaskExecutionGetDataResponse( - execution_data_locations[0], - execution_data_locations[1] + execution_data_locations[0], execution_data_locations[1] ) ) mock_client_factory.return_value = mock_client @@ -470,57 +360,34 @@ def test_get_task_execution_inputs(mock_client_factory, execution_data_locations m = MagicMock() type(m).id = PropertyMock( return_value=identifier.TaskExecutionIdentifier( - identifier.Identifier( - identifier.ResourceType.TASK, - 'project', - 'domain', - 'task-name', - 'version' - ), + identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ) + "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), ), - 0 + 0, ) ) inputs = engine.FlyteTaskExecution(m).get_inputs() assert len(inputs.literals) == 1 - assert inputs.literals['a'].scalar.primitive.integer == 1 + assert inputs.literals["a"].scalar.primitive.integer == 1 mock_client.get_task_execution_data.assert_called_once_with( identifier.TaskExecutionIdentifier( - identifier.Identifier( - identifier.ResourceType.TASK, - 'project', - 'domain', - 'task-name', - 'version' - ), + identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ) + "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), ), - 0 + 0, ) ) -@patch.object(engine._FlyteClientManager, '_CLIENT', new_callable=PropertyMock) +@patch.object(engine._FlyteClientManager, "_CLIENT", 
new_callable=PropertyMock) def test_get_task_execution_outputs(mock_client_factory, execution_data_locations): mock_client = MagicMock() mock_client.get_task_execution_data = MagicMock( return_value=_execution_models.TaskExecutionGetDataResponse( - execution_data_locations[0], - execution_data_locations[1] + execution_data_locations[0], execution_data_locations[1] ) ) mock_client_factory.return_value = mock_client @@ -528,66 +395,42 @@ def test_get_task_execution_outputs(mock_client_factory, execution_data_location m = MagicMock() type(m).id = PropertyMock( return_value=identifier.TaskExecutionIdentifier( - identifier.Identifier( - identifier.ResourceType.TASK, - 'project', - 'domain', - 'task-name', - 'version' - ), + identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ) + "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), ), - 0 + 0, ) ) inputs = engine.FlyteTaskExecution(m).get_outputs() assert len(inputs.literals) == 1 - assert inputs.literals['b'].scalar.primitive.integer == 2 + assert inputs.literals["b"].scalar.primitive.integer == 2 mock_client.get_task_execution_data.assert_called_once_with( identifier.TaskExecutionIdentifier( - identifier.Identifier( - identifier.ResourceType.TASK, - 'project', - 'domain', - 'task-name', - 'version' - ), + identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), identifier.NodeExecutionIdentifier( - "node-a", - identifier.WorkflowExecutionIdentifier( - "project", - "domain", - "name", - ) + "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), ), - 0 + 0, ) ) -@pytest.mark.parametrize("tasks", - [[_task_models.Task( - identifier.Identifier(identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"), - MagicMock(), - )], []]) -@patch.object(engine._FlyteClientManager, '_CLIENT', new_callable=PropertyMock) +@pytest.mark.parametrize( + "tasks", + [ + [_task_models.Task(identifier.Identifier(identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"), MagicMock(),)], + [], + ], +) +@patch.object(engine._FlyteClientManager, "_CLIENT", new_callable=PropertyMock) def test_fetch_latest_task(mock_client_factory, tasks): mock_client = MagicMock() - mock_client.list_tasks_paginated = MagicMock( - return_value=(tasks, 0) - ) + mock_client.list_tasks_paginated = MagicMock(return_value=(tasks, 0)) mock_client_factory.return_value = mock_client - task = engine.FlyteEngineFactory().fetch_latest_task( - _common_models.NamedEntityIdentifier("p", "d", "n") - ) + task = engine.FlyteEngineFactory().fetch_latest_task(_common_models.NamedEntityIdentifier("p", "d", "n")) if tasks: assert task.id == tasks[0].id @@ -597,5 +440,5 @@ def test_fetch_latest_task(mock_client_factory, tasks): mock_client.list_tasks_paginated.assert_called_once_with( _common_models.NamedEntityIdentifier("p", "d", "n"), limit=1, - sort_by=_common.Sort("created_at", _common.Sort.Direction.DESCENDING) + sort_by=_common.Sort("created_at", _common.Sort.Direction.DESCENDING), ) diff --git a/tests/flytekit/unit/engines/test_loader.py b/tests/flytekit/unit/engines/test_loader.py index b16209dc26..248e35dd18 100644 --- a/tests/flytekit/unit/engines/test_loader.py +++ b/tests/flytekit/unit/engines/test_loader.py @@ -1,13 +1,15 @@ from __future__ import absolute_import + +import pytest + from flytekit.engines import loader from flytekit.engines.unit 
import engine as _unit_engine -import pytest def test_unit_load(): - assert isinstance(loader.get_engine('unit'), _unit_engine.UnitTestEngineFactory) + assert isinstance(loader.get_engine("unit"), _unit_engine.UnitTestEngineFactory) def test_bad_load(): with pytest.raises(Exception): - loader.get_engine('badname') + loader.get_engine("badname") diff --git a/tests/flytekit/unit/interfaces/data/gcs/test_gcs_proxy.py b/tests/flytekit/unit/interfaces/data/gcs/test_gcs_proxy.py index f32b7c09ca..d1b29232d1 100644 --- a/tests/flytekit/unit/interfaces/data/gcs/test_gcs_proxy.py +++ b/tests/flytekit/unit/interfaces/data/gcs/test_gcs_proxy.py @@ -4,6 +4,7 @@ import mock as _mock import pytest as _pytest + from flytekit.interfaces.data.gcs import gcs_proxy as _gcs_proxy @@ -29,14 +30,10 @@ def gcs_proxy(): def test_upload_directory(mock_update_cmd_config_and_execute, gcs_proxy): local_path, remote_path = "/foo/*", "gs://bar/0/" gcs_proxy.upload_directory(local_path, remote_path) - mock_update_cmd_config_and_execute.assert_called_once_with( - ["gsutil", "cp", "-r", local_path, remote_path] - ) + mock_update_cmd_config_and_execute.assert_called_once_with(["gsutil", "cp", "-r", local_path, remote_path]) -def test_upload_directory_padding_wildcard_for_local_path( - mock_update_cmd_config_and_execute, gcs_proxy -): +def test_upload_directory_padding_wildcard_for_local_path(mock_update_cmd_config_and_execute, gcs_proxy): local_path, remote_path = "/foo", "gs://bar/0/" gcs_proxy.upload_directory(local_path, remote_path) mock_update_cmd_config_and_execute.assert_called_once_with( @@ -44,14 +41,10 @@ def test_upload_directory_padding_wildcard_for_local_path( ) -def test_upload_directory_padding_slash_for_remote_path( - mock_update_cmd_config_and_execute, gcs_proxy -): +def test_upload_directory_padding_slash_for_remote_path(mock_update_cmd_config_and_execute, gcs_proxy): local_path, remote_path = "/foo/*", "gs://bar/0" gcs_proxy.upload_directory(local_path, remote_path) - mock_update_cmd_config_and_execute.assert_called_once_with( - ["gsutil", "cp", "-r", local_path, remote_path + "/"] - ) + mock_update_cmd_config_and_execute.assert_called_once_with(["gsutil", "cp", "-r", local_path, remote_path + "/"]) def test_maybe_with_gsutil_parallelism_disabled(gcs_proxy): @@ -66,21 +59,13 @@ def test_maybe_with_gsutil_parallelism_enabled(gsutil_parallelism, gcs_proxy): assert cmd == ["gsutil", "-m", "cp", "-r", local_path, remote_path] -def test_download_with_parallelism( - mock_update_cmd_config_and_execute, gsutil_parallelism, gcs_proxy -): +def test_download_with_parallelism(mock_update_cmd_config_and_execute, gsutil_parallelism, gcs_proxy): local_path, remote_path = "/foo", "gs://bar/0/" gcs_proxy.download(remote_path, local_path) - mock_update_cmd_config_and_execute.assert_called_once_with( - ["gsutil", "-m", "cp", remote_path, local_path] - ) + mock_update_cmd_config_and_execute.assert_called_once_with(["gsutil", "-m", "cp", remote_path, local_path]) -def test_upload_directory_with_parallelism( - mock_update_cmd_config_and_execute, gsutil_parallelism, gcs_proxy -): +def test_upload_directory_with_parallelism(mock_update_cmd_config_and_execute, gsutil_parallelism, gcs_proxy): local_path, remote_path = "/foo/*", "gs://bar/0/" gcs_proxy.upload_directory(local_path, remote_path) - mock_update_cmd_config_and_execute.assert_called_once_with( - ["gsutil", "-m", "cp", "-r", local_path, remote_path] - ) + mock_update_cmd_config_and_execute.assert_called_once_with(["gsutil", "-m", "cp", "-r", local_path, remote_path]) 
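Note: the gcs_proxy hunks above all follow one testing pattern that the reformatting preserves — patch the helper that shells out, invoke the proxy method, then assert on the exact `gsutil` argv with `assert_called_once_with`. A minimal, self-contained sketch of that pattern follows; the `FakeGCSProxy` class, `build_copy_cmd` helper, and `_execute` hook are hypothetical stand-ins for illustration, not the real flytekit API:

```python
from unittest import mock  # the tests above use the `mock` package; unittest.mock behaves the same here


def build_copy_cmd(local_path, remote_path, recursive=True, parallel=False):
    # Hypothetical helper mirroring how a proxy might assemble a gsutil command.
    cmd = ["gsutil"]
    if parallel:
        cmd.append("-m")  # gsutil's parallelism flag, as toggled by the gsutil_parallelism fixture above
    cmd.append("cp")
    if recursive:
        cmd.append("-r")
    return cmd + [local_path, remote_path]


class FakeGCSProxy:
    """Hypothetical stand-in for the proxy object exercised in the hunks above."""

    def upload_directory(self, local_path, remote_path):
        self._execute(build_copy_cmd(local_path, remote_path))

    def _execute(self, cmd):
        raise NotImplementedError("patched out in tests")


def test_upload_directory_builds_expected_command():
    proxy = FakeGCSProxy()
    with mock.patch.object(FakeGCSProxy, "_execute") as mocked:
        proxy.upload_directory("/foo/*", "gs://bar/0/")
    # Same assertion style as the diff: verify the exact argv was produced, exactly once.
    mocked.assert_called_once_with(["gsutil", "cp", "-r", "/foo/*", "gs://bar/0/"])
```

Because the patch targets the class attribute rather than the instance, the mock is not bound to `proxy`, so the recorded call contains only the command list — which is why the assertions above can compare against the bare argv.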
diff --git a/tests/flytekit/unit/interfaces/test_random.py b/tests/flytekit/unit/interfaces/test_random.py index eacddbf38d..ee247a515c 100644 --- a/tests/flytekit/unit/interfaces/test_random.py +++ b/tests/flytekit/unit/interfaces/test_random.py @@ -1,7 +1,9 @@ from __future__ import absolute_import -from flytekit.interfaces import random + import random as global_random +from flytekit.interfaces import random + def test_isolated_random_state(): random.seed_flyte_random("abc") diff --git a/tests/flytekit/unit/models/admin/test_common.py b/tests/flytekit/unit/models/admin/test_common.py index 0eea504f8a..5ff6fc54f2 100644 --- a/tests/flytekit/unit/models/admin/test_common.py +++ b/tests/flytekit/unit/models/admin/test_common.py @@ -1,31 +1,32 @@ from __future__ import absolute_import -from flytekit.models.admin import common as _common import pytest as _pytest +from flytekit.models.admin import common as _common + def test_sort(): - o = _common.Sort(key='abc', direction=_common.Sort.Direction.ASCENDING) - assert o.key == 'abc' + o = _common.Sort(key="abc", direction=_common.Sort.Direction.ASCENDING) + assert o.key == "abc" assert o.direction == _common.Sort.Direction.ASCENDING o2 = _common.Sort.from_flyte_idl(o.to_flyte_idl()) assert o2 == o - assert o2.key == 'abc' + assert o2.key == "abc" assert o2.direction == _common.Sort.Direction.ASCENDING def test_sort_parse(): - o = _common.Sort.from_python_std(' asc(my"\wackyk3y) ') - assert o.key == 'my"\wackyk3y' + o = _common.Sort.from_python_std(' asc(my"\wackyk3y) ') # noqa: W605 + assert o.key == 'my"\wackyk3y' # noqa: W605 assert o.direction == _common.Sort.Direction.ASCENDING - o = _common.Sort.from_python_std(' desc( mykey ) ') - assert o.key == 'mykey' + o = _common.Sort.from_python_std(" desc( mykey ) ") + assert o.key == "mykey" assert o.direction == _common.Sort.Direction.DESCENDING with _pytest.raises(ValueError): - _common.Sort.from_python_std('asc(abc') + _common.Sort.from_python_std("asc(abc") with _pytest.raises(ValueError): - _common.Sort.from_python_std('asce(abc)') + _common.Sort.from_python_std("asce(abc)") diff --git a/tests/flytekit/unit/models/core/test_errors.py b/tests/flytekit/unit/models/core/test_errors.py index 81e3b3e2a0..935886b763 100644 --- a/tests/flytekit/unit/models/core/test_errors.py +++ b/tests/flytekit/unit/models/core/test_errors.py @@ -1,22 +1,23 @@ from __future__ import absolute_import + from flytekit.models.core import errors def test_container_error(): - obj = errors.ContainerError('code', 'my message', errors.ContainerError.Kind.RECOVERABLE) - assert obj.code == 'code' - assert obj.message == 'my message' + obj = errors.ContainerError("code", "my message", errors.ContainerError.Kind.RECOVERABLE) + assert obj.code == "code" + assert obj.message == "my message" assert obj.kind == errors.ContainerError.Kind.RECOVERABLE obj2 = errors.ContainerError.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 - assert obj2.code == 'code' - assert obj2.message == 'my message' + assert obj2.code == "code" + assert obj2.message == "my message" assert obj2.kind == errors.ContainerError.Kind.RECOVERABLE def test_error_document(): - ce = errors.ContainerError('code', 'my message', errors.ContainerError.Kind.RECOVERABLE) + ce = errors.ContainerError("code", "my message", errors.ContainerError.Kind.RECOVERABLE) obj = errors.ErrorDocument(ce) assert obj.error == ce diff --git a/tests/flytekit/unit/models/core/test_execution.py b/tests/flytekit/unit/models/core/test_execution.py index ab01c4eab7..c8bc7839f4 100644 --- 
a/tests/flytekit/unit/models/core/test_execution.py +++ b/tests/flytekit/unit/models/core/test_execution.py @@ -1,7 +1,9 @@ from __future__ import absolute_import -from flytekit.models.core import execution + import datetime +from flytekit.models.core import execution + def test_task_logs(): obj = execution.TaskLog("uri", "name", execution.TaskLog.MessageFormat.CSV, datetime.timedelta(days=30)) diff --git a/tests/flytekit/unit/models/core/test_identifier.py b/tests/flytekit/unit/models/core/test_identifier.py index a5fe6cdf00..633aa1df66 100644 --- a/tests/flytekit/unit/models/core/test_identifier.py +++ b/tests/flytekit/unit/models/core/test_identifier.py @@ -1,4 +1,5 @@ from __future__ import absolute_import + from flytekit.models.core import identifier @@ -21,10 +22,7 @@ def test_identifier(): def test_node_execution_identifier(): wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name") - obj = identifier.NodeExecutionIdentifier( - "node_id", - wf_exec_id - ) + obj = identifier.NodeExecutionIdentifier("node_id", wf_exec_id) assert obj.node_id == "node_id" assert obj.execution_id == wf_exec_id @@ -37,10 +35,7 @@ def test_node_execution_identifier(): def test_task_execution_identifier(): task_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version") wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name") - node_exec_id = identifier.NodeExecutionIdentifier( - "node_id", - wf_exec_id, - ) + node_exec_id = identifier.NodeExecutionIdentifier("node_id", wf_exec_id,) obj = identifier.TaskExecutionIdentifier(task_id, node_exec_id, 3) assert obj.retry_attempt == 3 assert obj.task_id == task_id @@ -65,7 +60,8 @@ def test_workflow_execution_identifier(): assert obj2.domain == "domain" assert obj2.name == "name" -def test_task_execution_identifier(): + +def test_identifier_emptiness(): empty_id = identifier.Identifier(identifier.ResourceType.UNSPECIFIED, "", "", "", "") not_empty_id = identifier.Identifier(identifier.ResourceType.UNSPECIFIED, "", "", "", "version") assert empty_id.is_empty diff --git a/tests/flytekit/unit/models/core/test_types.py b/tests/flytekit/unit/models/core/test_types.py index 866b7b57a7..55ce9abf16 100644 --- a/tests/flytekit/unit/models/core/test_types.py +++ b/tests/flytekit/unit/models/core/test_types.py @@ -1,7 +1,9 @@ from __future__ import absolute_import -from flytekit.models.core import types as _types + from flyteidl.core import types_pb2 as _types_pb2 +from flytekit.models.core import types as _types + def test_blob_dimensionality(): assert _types.BlobType.BlobDimensionality.SINGLE == _types_pb2.BlobType.SINGLE @@ -9,10 +11,7 @@ def test_blob_dimensionality(): def test_blob_type(): - o = _types.BlobType( - format="csv", - dimensionality=_types.BlobType.BlobDimensionality.SINGLE, - ) + o = _types.BlobType(format="csv", dimensionality=_types.BlobType.BlobDimensionality.SINGLE,) assert o.format == "csv" assert o.dimensionality == _types.BlobType.BlobDimensionality.SINGLE diff --git a/tests/flytekit/unit/models/core/test_workflow.py b/tests/flytekit/unit/models/core/test_workflow.py index 8a3cb26cae..c906c3f40a 100644 --- a/tests/flytekit/unit/models/core/test_workflow.py +++ b/tests/flytekit/unit/models/core/test_workflow.py @@ -2,14 +2,18 @@ from datetime import timedelta -from flytekit.models import literals as _literals, types as _types, interface as _interface -from flytekit.models.core import workflow as _workflow, identifier as _identifier, condition as _condition +from 
flytekit.models import interface as _interface +from flytekit.models import literals as _literals +from flytekit.models import types as _types +from flytekit.models.core import condition as _condition +from flytekit.models.core import identifier as _identifier +from flytekit.models.core import workflow as _workflow _generic_id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version") def test_node_metadata(): - obj = _workflow.NodeMetadata(name='node1', timeout=timedelta(seconds=10), retries=_literals.RetryStrategy(0)) + obj = _workflow.NodeMetadata(name="node1", timeout=timedelta(seconds=10), retries=_literals.RetryStrategy(0)) assert obj.timeout.seconds == 10 assert obj.retries.retries == 0 obj2 = _workflow.NodeMetadata.from_flyte_idl(obj.to_flyte_idl()) @@ -19,13 +23,13 @@ def test_node_metadata(): def test_alias(): - obj = _workflow.Alias(var='myvar', alias='myalias') - assert obj.alias == 'myalias' - assert obj.var == 'myvar' + obj = _workflow.Alias(var="myvar", alias="myalias") + assert obj.alias == "myalias" + assert obj.var == "myvar" obj2 = _workflow.Alias.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 - assert obj2.alias == 'myalias' - assert obj2.var == 'myvar' + assert obj2.alias == "myalias" + assert obj2.var == "myvar" def test_workflow_template(): @@ -35,19 +39,11 @@ def test_workflow_template(): wf_metadata = _workflow.WorkflowMetadata() wf_metadata_defaults = _workflow.WorkflowMetadataDefaults() typed_interface = _interface.TypedInterface( - {'a': _interface.Variable(int_type, "description1")}, - { - 'b': _interface.Variable(int_type, "description2"), - 'c': _interface.Variable(int_type, "description3") - } + {"a": _interface.Variable(int_type, "description1")}, + {"b": _interface.Variable(int_type, "description2"), "c": _interface.Variable(int_type, "description3")}, ) wf_node = _workflow.Node( - id='some:node:id', - metadata=nm, - inputs=[], - upstream_node_ids=[], - output_aliases=[], - task_node=task + id="some:node:id", metadata=nm, inputs=[], upstream_node_ids=[], output_aliases=[], task_node=task, ) obj = _workflow.WorkflowTemplate( id=_generic_id, @@ -55,14 +51,16 @@ def test_workflow_template(): metadata_defaults=wf_metadata_defaults, interface=typed_interface, nodes=[wf_node], - outputs=[]) + outputs=[], + ) obj2 = _workflow.WorkflowTemplate.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj def test_workflow_metadata_failure_policy(): obj = _workflow.WorkflowMetadata( - on_failure=_workflow.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE) + on_failure=_workflow.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE + ) obj2 = _workflow.WorkflowMetadata.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 assert obj.on_failure == _workflow.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE @@ -107,7 +105,7 @@ def test_workflow_node_sw(): def _get_sample_node_metadata(): - return _workflow.NodeMetadata(name='node1', timeout=timedelta(seconds=10), retries=_literals.RetryStrategy(0)) + return _workflow.NodeMetadata(name="node1", timeout=timedelta(seconds=10), retries=_literals.RetryStrategy(0)) def test_node_task_with_no_inputs(): @@ -115,21 +113,16 @@ def test_node_task_with_no_inputs(): task = _workflow.TaskNode(reference_id=_generic_id) obj = _workflow.Node( - id='some:node:id', - metadata=nm, - inputs=[], - upstream_node_ids=[], - output_aliases=[], - task_node=task + id="some:node:id", metadata=nm, inputs=[], upstream_node_ids=[], output_aliases=[], 
task_node=task, ) assert obj.target == task - assert obj.id == 'some:node:id' + assert obj.id == "some:node:id" assert obj.metadata == nm obj2 = _workflow.Node.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 assert obj2.target == task - assert obj2.id == 'some:node:id' + assert obj2.id == "some:node:id" assert obj2.metadata == nm @@ -138,19 +131,19 @@ def test_node_task_with_inputs(): task = _workflow.TaskNode(reference_id=_generic_id) bd = _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5))) bd2 = _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=99))) - binding = _literals.Binding(var='myvar', binding=bd) - binding2 = _literals.Binding(var='myothervar', binding=bd2) + binding = _literals.Binding(var="myvar", binding=bd) + binding2 = _literals.Binding(var="myothervar", binding=bd2) obj = _workflow.Node( - id='some:node:id', + id="some:node:id", metadata=nm, inputs=[binding, binding2], upstream_node_ids=[], output_aliases=[], - task_node=task + task_node=task, ) assert obj.target == task - assert obj.id == 'some:node:id' + assert obj.id == "some:node:id" assert obj.metadata == nm assert len(obj.inputs) == 2 assert obj.inputs[0] == binding @@ -158,7 +151,7 @@ def test_node_task_with_inputs(): obj2 = _workflow.Node.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 assert obj2.target == task - assert obj2.id == 'some:node:id' + assert obj2.id == "some:node:id" assert obj2.metadata == nm assert len(obj2.inputs) == 2 assert obj2.inputs[1] == binding2 @@ -169,48 +162,57 @@ def test_branch_node(): task = _workflow.TaskNode(reference_id=_generic_id) bd = _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5))) bd2 = _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=99))) - binding = _literals.Binding(var='myvar', binding=bd) - binding2 = _literals.Binding(var='myothervar', binding=bd2) + binding = _literals.Binding(var="myvar", binding=bd) + binding2 = _literals.Binding(var="myothervar", binding=bd2) obj = _workflow.Node( - id='some:node:id', + id="some:node:id", metadata=nm, inputs=[binding, binding2], upstream_node_ids=[], output_aliases=[], - task_node=task + task_node=task, ) - bn = _workflow.BranchNode(_workflow.IfElseBlock( - case=_workflow.IfBlock( - condition=_condition.BooleanExpression( - comparison=_condition.ComparisonExpression(_condition.ComparisonExpression.Operator.EQ, - _condition.Operand(primitive=_literals.Primitive(integer=5)), - _condition.Operand( - primitive=_literals.Primitive(integer=2)))), - then_node=obj - ), - other=[_workflow.IfBlock( - condition=_condition.BooleanExpression( - conjunction=_condition.ConjunctionExpression(_condition.ConjunctionExpression.LogicalOperator.AND, - _condition.BooleanExpression( - comparison=_condition.ComparisonExpression( - _condition.ComparisonExpression.Operator.EQ, - _condition.Operand( - primitive=_literals.Primitive(integer=5)), - _condition.Operand( - primitive=_literals.Primitive(integer=2)))), - _condition.BooleanExpression( - comparison=_condition.ComparisonExpression( - _condition.ComparisonExpression.Operator.EQ, - _condition.Operand( - primitive=_literals.Primitive(integer=5)), - _condition.Operand( - primitive=_literals.Primitive(integer=2)))))), - then_node=obj - )], - else_node=obj - )) + bn = _workflow.BranchNode( + _workflow.IfElseBlock( + case=_workflow.IfBlock( + condition=_condition.BooleanExpression( + comparison=_condition.ComparisonExpression( + 
_condition.ComparisonExpression.Operator.EQ, + _condition.Operand(primitive=_literals.Primitive(integer=5)), + _condition.Operand(primitive=_literals.Primitive(integer=2)), + ) + ), + then_node=obj, + ), + other=[ + _workflow.IfBlock( + condition=_condition.BooleanExpression( + conjunction=_condition.ConjunctionExpression( + _condition.ConjunctionExpression.LogicalOperator.AND, + _condition.BooleanExpression( + comparison=_condition.ComparisonExpression( + _condition.ComparisonExpression.Operator.EQ, + _condition.Operand(primitive=_literals.Primitive(integer=5)), + _condition.Operand(primitive=_literals.Primitive(integer=2)), + ) + ), + _condition.BooleanExpression( + comparison=_condition.ComparisonExpression( + _condition.ComparisonExpression.Operator.EQ, + _condition.Operand(primitive=_literals.Primitive(integer=5)), + _condition.Operand(primitive=_literals.Primitive(integer=2)), + ) + ), + ) + ), + then_node=obj, + ) + ], + else_node=obj, + ) + ) bn2 = _workflow.BranchNode.from_flyte_idl(bn.to_flyte_idl()) assert bn == bn2 diff --git a/tests/flytekit/unit/models/test_common.py b/tests/flytekit/unit/models/test_common.py index 569b7a251a..eb41871eb0 100644 --- a/tests/flytekit/unit/models/test_common.py +++ b/tests/flytekit/unit/models/test_common.py @@ -5,29 +5,32 @@ def test_notification_email(): - obj = _common.EmailNotification(['a', 'b', 'c']) - assert obj.recipients_email == ['a', 'b', 'c'] + obj = _common.EmailNotification(["a", "b", "c"]) + assert obj.recipients_email == ["a", "b", "c"] obj2 = _common.EmailNotification.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj def test_notification_pagerduty(): - obj = _common.PagerDutyNotification(['a', 'b', 'c']) - assert obj.recipients_email == ['a', 'b', 'c'] + obj = _common.PagerDutyNotification(["a", "b", "c"]) + assert obj.recipients_email == ["a", "b", "c"] obj2 = _common.PagerDutyNotification.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj def test_notification_slack(): - obj = _common.SlackNotification(['a', 'b', 'c']) - assert obj.recipients_email == ['a', 'b', 'c'] + obj = _common.SlackNotification(["a", "b", "c"]) + assert obj.recipients_email == ["a", "b", "c"] obj2 = _common.SlackNotification.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj def test_notification(): - phases = [_execution.WorkflowExecutionPhase.FAILED, _execution.WorkflowExecutionPhase.SUCCEEDED] - recipients = ['a', 'b', 'c'] + phases = [ + _execution.WorkflowExecutionPhase.FAILED, + _execution.WorkflowExecutionPhase.SUCCEEDED, + ] + recipients = ["a", "b", "c"] obj = _common.Notification(phases, email=_common.EmailNotification(recipients)) assert obj.phases == phases @@ -83,7 +86,7 @@ def test_auth_role(): def test_raw_output_data_config(): - obj = _common.RawOutputDataConfig('s3://bucket') - assert obj.output_location_prefix == 's3://bucket' + obj = _common.RawOutputDataConfig("s3://bucket") + assert obj.output_location_prefix == "s3://bucket" obj2 = _common.RawOutputDataConfig.from_flyte_idl(obj.to_flyte_idl()) assert obj2 == obj diff --git a/tests/flytekit/unit/models/test_dynamic_job.py b/tests/flytekit/unit/models/test_dynamic_job.py index 3d209f5769..146607d533 100644 --- a/tests/flytekit/unit/models/test_dynamic_job.py +++ b/tests/flytekit/unit/models/test_dynamic_job.py @@ -1,14 +1,17 @@ from __future__ import absolute_import +from datetime import timedelta as _timedelta from itertools import product import pytest -from datetime import timedelta as _timedelta from google.protobuf import text_format -from flytekit.models import 
literals as _literals, dynamic_job as _dynamic_job, array_job as _array_job, \ - task as _task -from flytekit.models.core import workflow as _workflow, identifier as _identifier +from flytekit.models import array_job as _array_job +from flytekit.models import dynamic_job as _dynamic_job +from flytekit.models import literals as _literals +from flytekit.models import task as _task +from flytekit.models.core import identifier as _identifier +from flytekit.models.core import workflow as _workflow from tests.flytekit.common import parameterizers LIST_OF_DYNAMIC_TASKS = [ @@ -19,34 +22,30 @@ interfaces, _array_job.ArrayJob(2, 2, 2).to_dict(), container=_task.Container( - "my_image", - ["this", "is", "a", "cmd"], - ["this", "is", "an", "arg"], - resources, - {'a': 'b'}, - {'d': 'e'} - ) + "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, + ), ) for task_metadata, interfaces, resources in product( - parameterizers.LIST_OF_TASK_METADATA, - parameterizers.LIST_OF_INTERFACES, - parameterizers.LIST_OF_RESOURCES + parameterizers.LIST_OF_TASK_METADATA, parameterizers.LIST_OF_INTERFACES, parameterizers.LIST_OF_RESOURCES, ) ] -@pytest.mark.parametrize("task", - LIST_OF_DYNAMIC_TASKS) +@pytest.mark.parametrize("task", LIST_OF_DYNAMIC_TASKS) def test_future_task_document(task): rs = _literals.RetryStrategy(0) - nm = _workflow.NodeMetadata('node-name', _timedelta(minutes=10), rs) - n = _workflow.Node(id="id", metadata=nm, inputs=[], upstream_node_ids=[], - output_aliases=[], task_node=_workflow.TaskNode(task.id)) + nm = _workflow.NodeMetadata("node-name", _timedelta(minutes=10), rs) + n = _workflow.Node( + id="id", metadata=nm, inputs=[], upstream_node_ids=[], output_aliases=[], task_node=_workflow.TaskNode(task.id), + ) n.to_flyte_idl() - doc = _dynamic_job.DynamicJobSpec(tasks=[task], - nodes=[n], - min_successes=1, - outputs=[_literals.Binding("var", _literals.BindingData())], - subworkflows=[]) + doc = _dynamic_job.DynamicJobSpec( + tasks=[task], + nodes=[n], + min_successes=1, + outputs=[_literals.Binding("var", _literals.BindingData())], + subworkflows=[], + ) assert text_format.MessageToString(doc.to_flyte_idl()) == text_format.MessageToString( - _dynamic_job.DynamicJobSpec.from_flyte_idl(doc.to_flyte_idl()).to_flyte_idl()) + _dynamic_job.DynamicJobSpec.from_flyte_idl(doc.to_flyte_idl()).to_flyte_idl() + ) diff --git a/tests/flytekit/unit/models/test_dynamic_wfs.py b/tests/flytekit/unit/models/test_dynamic_wfs.py index a1ca54fa6d..5a4ce948e3 100644 --- a/tests/flytekit/unit/models/test_dynamic_wfs.py +++ b/tests/flytekit/unit/models/test_dynamic_wfs.py @@ -1,9 +1,9 @@ -from __future__ import absolute_import -from __future__ import print_function +from __future__ import absolute_import, print_function -from flytekit.sdk import tasks as _tasks, workflow as _workflow -from flytekit.sdk.types import Types as _Types from flytekit.common import constants as _sdk_constants +from flytekit.sdk import tasks as _tasks +from flytekit.sdk import workflow as _workflow +from flytekit.sdk.types import Types as _Types @_tasks.inputs(num=_Types.Integer) @@ -124,4 +124,4 @@ def test_dynamic_launch_plan_yielding_of_input_only_workflow(): # map the LiteralMap of the inputs of that node input_key = "{}/inputs.pb".format(dj_spec.nodes[0].id) lp_input_map = outputs[input_key] - assert lp_input_map.literals['a'] is not None + assert lp_input_map.literals["a"] is not None diff --git a/tests/flytekit/unit/models/test_execution.py 
b/tests/flytekit/unit/models/test_execution.py index 1dea51e4d7..62feaf4691 100644 --- a/tests/flytekit/unit/models/test_execution.py +++ b/tests/flytekit/unit/models/test_execution.py @@ -1,40 +1,41 @@ from __future__ import absolute_import -from flytekit.models import execution as _execution, literals as _literal_models, common as _common_models -from flytekit.models.core import execution as _core_exec, identifier as _identifier -from tests.flytekit.common import parameterizers as _parameterizers + import pytest +from flytekit.models import common as _common_models +from flytekit.models import execution as _execution +from flytekit.models.core import execution as _core_exec +from flytekit.models.core import identifier as _identifier +from tests.flytekit.common import parameterizers as _parameterizers + def test_execution_metadata(): - obj = _execution.ExecutionMetadata(_execution.ExecutionMetadata.ExecutionMode.MANUAL, 'tester', 1) + obj = _execution.ExecutionMetadata(_execution.ExecutionMetadata.ExecutionMode.MANUAL, "tester", 1) assert obj.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL - assert obj.principal == 'tester' + assert obj.principal == "tester" assert obj.nesting == 1 obj2 = _execution.ExecutionMetadata.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 assert obj2.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL - assert obj2.principal == 'tester' + assert obj2.principal == "tester" assert obj2.nesting == 1 -@pytest.mark.parametrize( - "literal_value_pair", - _parameterizers.LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE -) +@pytest.mark.parametrize("literal_value_pair", _parameterizers.LIST_OF_SCALAR_LITERALS_AND_PYTHON_VALUE) def test_execution_spec(literal_value_pair): literal_value, _ = literal_value_pair obj = _execution.ExecutionSpec( _identifier.Identifier(_identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version"), - _execution.ExecutionMetadata(_execution.ExecutionMetadata.ExecutionMode.MANUAL, 'tester', 1), + _execution.ExecutionMetadata(_execution.ExecutionMetadata.ExecutionMode.MANUAL, "tester", 1), notifications=_execution.NotificationList( [ _common_models.Notification( [_core_exec.WorkflowExecutionPhase.ABORTED], - pager_duty=_common_models.PagerDutyNotification(recipients_email=['a', 'b', 'c']) + pager_duty=_common_models.PagerDutyNotification(recipients_email=["a", "b", "c"]), ) ] - ) + ), ) assert obj.launch_plan.resource_type == _identifier.ResourceType.LAUNCH_PLAN assert obj.launch_plan.domain == "domain" @@ -43,9 +44,13 @@ def test_execution_spec(literal_value_pair): assert obj.launch_plan.version == "version" assert obj.metadata.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL assert obj.metadata.nesting == 1 - assert obj.metadata.principal == 'tester' + assert obj.metadata.principal == "tester" assert obj.notifications.notifications[0].phases == [_core_exec.WorkflowExecutionPhase.ABORTED] - assert obj.notifications.notifications[0].pager_duty.recipients_email == ['a', 'b', 'c'] + assert obj.notifications.notifications[0].pager_duty.recipients_email == [ + "a", + "b", + "c", + ] assert obj.disable_all is None obj2 = _execution.ExecutionSpec.from_flyte_idl(obj.to_flyte_idl()) @@ -57,15 +62,19 @@ def test_execution_spec(literal_value_pair): assert obj2.launch_plan.version == "version" assert obj2.metadata.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL assert obj2.metadata.nesting == 1 - assert obj2.metadata.principal == 'tester' + assert obj2.metadata.principal == "tester" assert 
obj2.notifications.notifications[0].phases == [_core_exec.WorkflowExecutionPhase.ABORTED] - assert obj2.notifications.notifications[0].pager_duty.recipients_email == ['a', 'b', 'c'] + assert obj2.notifications.notifications[0].pager_duty.recipients_email == [ + "a", + "b", + "c", + ] assert obj2.disable_all is None obj = _execution.ExecutionSpec( _identifier.Identifier(_identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version"), - _execution.ExecutionMetadata(_execution.ExecutionMetadata.ExecutionMode.MANUAL, 'tester', 1), - disable_all=True + _execution.ExecutionMetadata(_execution.ExecutionMetadata.ExecutionMode.MANUAL, "tester", 1), + disable_all=True, ) assert obj.launch_plan.resource_type == _identifier.ResourceType.LAUNCH_PLAN assert obj.launch_plan.domain == "domain" @@ -74,7 +83,7 @@ def test_execution_spec(literal_value_pair): assert obj.launch_plan.version == "version" assert obj.metadata.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL assert obj.metadata.nesting == 1 - assert obj.metadata.principal == 'tester' + assert obj.metadata.principal == "tester" assert obj.notifications is None assert obj.disable_all is True @@ -87,7 +96,7 @@ def test_execution_spec(literal_value_pair): assert obj2.launch_plan.version == "version" assert obj2.metadata.mode == _execution.ExecutionMetadata.ExecutionMode.MANUAL assert obj2.metadata.nesting == 1 - assert obj2.metadata.principal == 'tester' + assert obj2.metadata.principal == "tester" assert obj2.notifications is None assert obj2.disable_all is True diff --git a/tests/flytekit/unit/models/test_filters.py b/tests/flytekit/unit/models/test_filters.py index bc678afc46..5c4c432a9e 100644 --- a/tests/flytekit/unit/models/test_filters.py +++ b/tests/flytekit/unit/models/test_filters.py @@ -1,4 +1,5 @@ from __future__ import absolute_import + from flytekit.models import filters @@ -35,9 +36,6 @@ def test_contains_filter(): def test_filter_list(): - fl = filters.FilterList([ - filters.Equal("domain", "staging"), - filters.NotEqual("project", "FakeProject") - ]) + fl = filters.FilterList([filters.Equal("domain", "staging"), filters.NotEqual("project", "FakeProject")]) assert fl.to_flyte_idl() == "eq(domain,staging)+neq(project,FakeProject)" diff --git a/tests/flytekit/unit/models/test_interface.py b/tests/flytekit/unit/models/test_interface.py index 742d42f62f..316b970d83 100644 --- a/tests/flytekit/unit/models/test_interface.py +++ b/tests/flytekit/unit/models/test_interface.py @@ -17,19 +17,16 @@ def test_variable_type(literal_type): @pytest.mark.parametrize("literal_type", LIST_OF_ALL_LITERAL_TYPES) def test_typed_interface(literal_type): typed_interface = interface.TypedInterface( - {'a': interface.Variable(literal_type, "description1")}, - { - 'b': interface.Variable(literal_type, "description2"), - 'c': interface.Variable(literal_type, "description3") - } + {"a": interface.Variable(literal_type, "description1")}, + {"b": interface.Variable(literal_type, "description2"), "c": interface.Variable(literal_type, "description3")}, ) - assert typed_interface.inputs['a'].type == literal_type - assert typed_interface.outputs['b'].type == literal_type - assert typed_interface.outputs['c'].type == literal_type - assert typed_interface.inputs['a'].description == "description1" - assert typed_interface.outputs['b'].description == "description2" - assert typed_interface.outputs['c'].description == "description3" + assert typed_interface.inputs["a"].type == literal_type + assert typed_interface.outputs["b"].type == literal_type + 
assert typed_interface.outputs["c"].type == literal_type + assert typed_interface.inputs["a"].description == "description1" + assert typed_interface.outputs["b"].description == "description2" + assert typed_interface.outputs["c"].description == "description3" assert len(typed_interface.inputs) == 1 assert len(typed_interface.outputs) == 2 @@ -37,18 +34,18 @@ def test_typed_interface(literal_type): deserialized_typed_interface = interface.TypedInterface.from_flyte_idl(pb) assert typed_interface == deserialized_typed_interface - assert deserialized_typed_interface.inputs['a'].type == literal_type - assert deserialized_typed_interface.outputs['b'].type == literal_type - assert deserialized_typed_interface.outputs['c'].type == literal_type - assert deserialized_typed_interface.inputs['a'].description == "description1" - assert deserialized_typed_interface.outputs['b'].description == "description2" - assert deserialized_typed_interface.outputs['c'].description == "description3" + assert deserialized_typed_interface.inputs["a"].type == literal_type + assert deserialized_typed_interface.outputs["b"].type == literal_type + assert deserialized_typed_interface.outputs["c"].type == literal_type + assert deserialized_typed_interface.inputs["a"].description == "description1" + assert deserialized_typed_interface.outputs["b"].description == "description2" + assert deserialized_typed_interface.outputs["c"].description == "description3" assert len(deserialized_typed_interface.inputs) == 1 assert len(deserialized_typed_interface.outputs) == 2 def test_parameter(): - v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), 'asdf asdf asdf') + v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf") obj = interface.Parameter(var=v) assert obj.var == v @@ -58,17 +55,17 @@ def test_parameter(): def test_parameter_map(): - v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), 'asdf asdf asdf') + v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf") p = interface.Parameter(var=v) - obj = interface.ParameterMap({'ppp': p}) + obj = interface.ParameterMap({"ppp": p}) obj2 = interface.ParameterMap.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 def test_variable_map(): - v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), 'asdf asdf asdf') - obj = interface.VariableMap({'vvv': v}) + v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf") + obj = interface.VariableMap({"vvv": v}) obj2 = interface.VariableMap.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 diff --git a/tests/flytekit/unit/models/test_launch_plan.py b/tests/flytekit/unit/models/test_launch_plan.py index 5ce2ef1999..4617313d26 100644 --- a/tests/flytekit/unit/models/test_launch_plan.py +++ b/tests/flytekit/unit/models/test_launch_plan.py @@ -1,6 +1,6 @@ from __future__ import absolute_import -from flytekit.models import launch_plan, schedule, interface, types, literals, common +from flytekit.models import common, interface, launch_plan, literals, schedule, types from flytekit.models.core import identifier @@ -11,7 +11,7 @@ def test_metadata(): def test_metadata_schedule(): - s = schedule.Schedule('asdf', '1 3 4 5 6 7') + s = schedule.Schedule("asdf", "1 3 4 5 6 7") obj = launch_plan.LaunchPlanMetadata(schedule=s, notifications=[]) assert obj.schedule == s obj2 = launch_plan.LaunchPlanMetadata.from_flyte_idl(obj.to_flyte_idl()) @@ -20,13 +20,14 @@ def 
test_metadata_schedule(): def test_lp_closure(): - v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), 'asdf asdf asdf') + v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf") p = interface.Parameter(var=v) - parameter_map = interface.ParameterMap({'ppp': p}) + parameter_map = interface.ParameterMap({"ppp": p}) parameter_map.to_flyte_idl() - variable_map = interface.VariableMap({'vvv': v}) - obj = launch_plan.LaunchPlanClosure(state=launch_plan.LaunchPlanState.ACTIVE, expected_inputs=parameter_map, - expected_outputs=variable_map) + variable_map = interface.VariableMap({"vvv": v}) + obj = launch_plan.LaunchPlanClosure( + state=launch_plan.LaunchPlanState.ACTIVE, expected_inputs=parameter_map, expected_outputs=variable_map, + ) assert obj.expected_inputs == parameter_map assert obj.expected_outputs == variable_map @@ -39,36 +40,48 @@ def test_lp_closure(): def test_launch_plan_spec(): identifier_model = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version") - s = schedule.Schedule('asdf', '1 3 4 5 6 7') + s = schedule.Schedule("asdf", "1 3 4 5 6 7") launch_plan_metadata_model = launch_plan.LaunchPlanMetadata(schedule=s, notifications=[]) - v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), 'asdf asdf asdf') + v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf") p = interface.Parameter(var=v) - parameter_map = interface.ParameterMap({'ppp': p}) + parameter_map = interface.ParameterMap({"ppp": p}) fixed_inputs = literals.LiteralMap( - { - 'a': literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))) - } + {"a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1)))} ) labels_model = common.Labels({}) annotations_model = common.Annotations({"my": "annotation"}) - auth_role_model = common.AuthRole(assumable_iam_role='my:iam:role') - raw_data_output_config = common.RawOutputDataConfig('s3://bucket') - empty_raw_data_output_config = common.RawOutputDataConfig('') - - lp_spec_raw_output_prefixed = launch_plan.LaunchPlanSpec(identifier_model, launch_plan_metadata_model, - parameter_map, fixed_inputs, labels_model, - annotations_model, auth_role_model, raw_data_output_config) + auth_role_model = common.AuthRole(assumable_iam_role="my:iam:role") + raw_data_output_config = common.RawOutputDataConfig("s3://bucket") + empty_raw_data_output_config = common.RawOutputDataConfig("") + + lp_spec_raw_output_prefixed = launch_plan.LaunchPlanSpec( + identifier_model, + launch_plan_metadata_model, + parameter_map, + fixed_inputs, + labels_model, + annotations_model, + auth_role_model, + raw_data_output_config, + ) obj2 = launch_plan.LaunchPlanSpec.from_flyte_idl(lp_spec_raw_output_prefixed.to_flyte_idl()) assert obj2 == lp_spec_raw_output_prefixed - lp_spec_no_prefix = launch_plan.LaunchPlanSpec(identifier_model, launch_plan_metadata_model, - parameter_map, fixed_inputs, labels_model, - annotations_model, auth_role_model, empty_raw_data_output_config) + lp_spec_no_prefix = launch_plan.LaunchPlanSpec( + identifier_model, + launch_plan_metadata_model, + parameter_map, + fixed_inputs, + labels_model, + annotations_model, + auth_role_model, + empty_raw_data_output_config, + ) obj2 = launch_plan.LaunchPlanSpec.from_flyte_idl(lp_spec_no_prefix.to_flyte_idl()) assert obj2 == lp_spec_no_prefix diff --git a/tests/flytekit/unit/models/test_literals.py b/tests/flytekit/unit/models/test_literals.py index 
33218a155a..4790881888 100644 --- a/tests/flytekit/unit/models/test_literals.py +++ b/tests/flytekit/unit/models/test_literals.py @@ -1,10 +1,14 @@ from __future__ import absolute_import + from datetime import datetime, timedelta -from flytekit.models import literals, types as _types -from tests.flytekit.common import parameterizers + import pytest import pytz +from flytekit.models import literals +from flytekit.models import types as _types +from tests.flytekit.common import parameterizers + def test_retry_strategy(): obj = literals.RetryStrategy(3) @@ -37,7 +41,7 @@ def test_integer_primitive(): assert obj != literals.Primitive(datetime=datetime.now()) assert obj != literals.Primitive(duration=timedelta(minutes=1)) assert obj != literals.Primitive(float_value=1.0) - assert obj != literals.Primitive(string_value='abc') + assert obj != literals.Primitive(string_value="abc") obj2 = literals.Primitive.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 @@ -53,7 +57,7 @@ def test_integer_primitive(): assert obj2 != literals.Primitive(datetime=datetime.now()) assert obj2 != literals.Primitive(duration=timedelta(minutes=1)) assert obj2 != literals.Primitive(float_value=1.0) - assert obj2 != literals.Primitive(string_value='abc') + assert obj2 != literals.Primitive(string_value="abc") obj3 = literals.Primitive(integer=0) assert obj3.value == 0 @@ -76,7 +80,7 @@ def test_boolean_primitive(): assert obj != literals.Primitive(datetime=datetime.now()) assert obj != literals.Primitive(duration=timedelta(minutes=1)) assert obj != literals.Primitive(float_value=1.0) - assert obj != literals.Primitive(string_value='abc') + assert obj != literals.Primitive(string_value="abc") obj2 = literals.Primitive.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 @@ -92,7 +96,7 @@ def test_boolean_primitive(): assert obj2 != literals.Primitive(datetime=datetime.now()) assert obj2 != literals.Primitive(duration=timedelta(minutes=1)) assert obj2 != literals.Primitive(float_value=1.0) - assert obj2 != literals.Primitive(string_value='abc') + assert obj2 != literals.Primitive(string_value="abc") obj3 = literals.Primitive(boolean=False) assert obj3.value is False @@ -116,7 +120,7 @@ def test_datetime_primitive(): assert obj != literals.Primitive(datetime=dt + timedelta(seconds=1)) assert obj != literals.Primitive(duration=timedelta(minutes=1)) assert obj != literals.Primitive(float_value=1.0) - assert obj != literals.Primitive(string_value='abc') + assert obj != literals.Primitive(string_value="abc") obj2 = literals.Primitive.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 @@ -132,7 +136,7 @@ def test_datetime_primitive(): assert obj2 != literals.Primitive(datetime=dt + timedelta(seconds=1)) assert obj2 != literals.Primitive(duration=timedelta(minutes=1)) assert obj2 != literals.Primitive(float_value=1.0) - assert obj2 != literals.Primitive(string_value='abc') + assert obj2 != literals.Primitive(string_value="abc") with pytest.raises(Exception): literals.Primitive(datetime=1.0).to_flyte_idl() @@ -153,7 +157,7 @@ def test_duration_primitive(): assert obj != literals.Primitive(datetime=datetime.now()) assert obj != literals.Primitive(duration=timedelta(minutes=1)) assert obj != literals.Primitive(float_value=1.0) - assert obj != literals.Primitive(string_value='abc') + assert obj != literals.Primitive(string_value="abc") obj2 = literals.Primitive.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 @@ -169,7 +173,7 @@ def test_duration_primitive(): assert obj2 != literals.Primitive(datetime=datetime.now()) 
assert obj2 != literals.Primitive(duration=timedelta(minutes=1)) assert obj2 != literals.Primitive(float_value=1.0) - assert obj2 != literals.Primitive(string_value='abc') + assert obj2 != literals.Primitive(string_value="abc") with pytest.raises(Exception): literals.Primitive(duration=1.0).to_flyte_idl() @@ -189,7 +193,7 @@ def test_float_primitive(): assert obj != literals.Primitive(datetime=datetime.now()) assert obj != literals.Primitive(duration=timedelta(minutes=1)) assert obj != literals.Primitive(float_value=0.0) - assert obj != literals.Primitive(string_value='abc') + assert obj != literals.Primitive(string_value="abc") obj2 = literals.Primitive.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 @@ -205,30 +209,30 @@ def test_float_primitive(): assert obj2 != literals.Primitive(datetime=datetime.now()) assert obj2 != literals.Primitive(duration=timedelta(minutes=1)) assert obj2 != literals.Primitive(float_value=0.0) - assert obj2 != literals.Primitive(string_value='abc') + assert obj2 != literals.Primitive(string_value="abc") obj3 = literals.Primitive(float_value=0.0) assert obj3.value == 0.0 with pytest.raises(Exception): - literals.Primitive(float_value='abc').to_flyte_idl() + literals.Primitive(float_value="abc").to_flyte_idl() def test_string_primitive(): - obj = literals.Primitive(string_value='abc') + obj = literals.Primitive(string_value="abc") assert obj.integer is None assert obj.boolean is None assert obj.datetime is None assert obj.duration is None assert obj.float_value is None - assert obj.string_value == 'abc' - assert obj.value == 'abc' + assert obj.string_value == "abc" + assert obj.value == "abc" assert obj != literals.Primitive(integer=0) assert obj != literals.Primitive(boolean=False) assert obj != literals.Primitive(datetime=datetime.now()) assert obj != literals.Primitive(duration=timedelta(minutes=1)) assert obj != literals.Primitive(float_value=0.0) - assert obj != literals.Primitive(string_value='cba') + assert obj != literals.Primitive(string_value="cba") obj2 = literals.Primitive.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 @@ -237,14 +241,14 @@ def test_string_primitive(): assert obj2.datetime is None assert obj2.duration is None assert obj2.float_value is None - assert obj2.string_value == 'abc' - assert obj2.value == 'abc' + assert obj2.string_value == "abc" + assert obj2.value == "abc" assert obj2 != literals.Primitive(integer=0) assert obj2 != literals.Primitive(boolean=False) assert obj2 != literals.Primitive(datetime=datetime.now()) assert obj2 != literals.Primitive(duration=timedelta(minutes=1)) assert obj2 != literals.Primitive(float_value=0.0) - assert obj2 != literals.Primitive(string_value='bca') + assert obj2 != literals.Primitive(string_value="bca") obj3 = literals.Primitive(string_value="") assert obj3.value == "" @@ -299,12 +303,7 @@ def test_scalar_error(): def test_scalar_binary(): - obj = literals.Scalar( - binary=literals.Binary( - b"value", - "taggy" - ) - ) + obj = literals.Scalar(binary=literals.Binary(b"value", "taggy")) assert obj.primitive is None assert obj.error is None assert obj.blob is None @@ -327,14 +326,16 @@ def test_scalar_binary(): def test_scalar_schema(): - schema_type = _types.SchemaType([ - _types.SchemaType.SchemaColumn("a", _types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - _types.SchemaType.SchemaColumn("b", _types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - _types.SchemaType.SchemaColumn("c", _types.SchemaType.SchemaColumn.SchemaColumnType.STRING), - 
_types.SchemaType.SchemaColumn("d", _types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - _types.SchemaType.SchemaColumn("e", _types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - _types.SchemaType.SchemaColumn("f", _types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN) - ]) + schema_type = _types.SchemaType( + [ + _types.SchemaType.SchemaColumn("a", _types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), + _types.SchemaType.SchemaColumn("b", _types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), + _types.SchemaType.SchemaColumn("c", _types.SchemaType.SchemaColumn.SchemaColumnType.STRING), + _types.SchemaType.SchemaColumn("d", _types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), + _types.SchemaType.SchemaColumn("e", _types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), + _types.SchemaType.SchemaColumn("f", _types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), + ] + ) schema = literals.Schema(uri="asdf", type=schema_type) obj = literals.Scalar(schema=schema) @@ -344,7 +345,7 @@ def test_scalar_schema(): assert obj.binary is None assert obj.schema is not None assert obj.none_type is None - assert obj.value.type.columns[0].name == 'a' + assert obj.value.type.columns[0].name == "a" assert len(obj.value.type.columns) == 6 obj2 = literals.Scalar.from_flyte_idl(obj.to_flyte_idl()) @@ -355,7 +356,7 @@ def test_scalar_schema(): assert obj2.binary is None assert obj2.schema is not None assert obj2.none_type is None - assert obj2.value.type.columns[0].name == 'a' + assert obj2.value.type.columns[0].name == "a" assert len(obj2.value.type.columns) == 6 @@ -378,33 +379,34 @@ def test_binding_data_map(): b1 = literals.BindingData(scalar=literals.Scalar(primitive=literals.Primitive(integer=5))) b2 = literals.BindingData(scalar=literals.Scalar(primitive=literals.Primitive(integer=57))) b3 = literals.BindingData(scalar=literals.Scalar(primitive=literals.Primitive(integer=2))) - binding_map_sub = literals.BindingDataMap(bindings={'first': b1, 'second': b2}) - binding_map = literals.BindingDataMap(bindings={'three': b3, - 'sample_map': literals.BindingData(map=binding_map_sub)}) + binding_map_sub = literals.BindingDataMap(bindings={"first": b1, "second": b2}) + binding_map = literals.BindingDataMap( + bindings={"three": b3, "sample_map": literals.BindingData(map=binding_map_sub)} + ) obj = literals.BindingData(map=binding_map) assert obj.scalar is None assert obj.promise is None assert obj.collection is None - assert obj.value.bindings['three'].value.value.value == 2 - assert obj.value.bindings['sample_map'].value.bindings['second'].value.value.value == 57 + assert obj.value.bindings["three"].value.value.value == 2 + assert obj.value.bindings["sample_map"].value.bindings["second"].value.value.value == 57 obj2 = literals.BindingData.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 assert obj2.scalar is None assert obj2.promise is None assert obj2.collection is None - assert obj2.value.bindings['three'].value.value.value == 2 - assert obj2.value.bindings['sample_map'].value.bindings['first'].value.value.value == 5 + assert obj2.value.bindings["three"].value.value.value == 2 + assert obj2.value.bindings["sample_map"].value.bindings["first"].value.value.value == 5 def test_binding_data_promise(): - obj = literals.BindingData(promise=_types.OutputReference('some_node', 'myvar')) + obj = literals.BindingData(promise=_types.OutputReference("some_node", "myvar")) assert obj.scalar is None assert obj.promise is not None assert obj.collection is None assert obj.map is None - 
assert obj.value.node_id == 'some_node' - assert obj.value.var == 'myvar' + assert obj.value.node_id == "some_node" + assert obj.value.var == "myvar" obj2 = literals.BindingData.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 diff --git a/tests/flytekit/unit/models/test_named_entity.py b/tests/flytekit/unit/models/test_named_entity.py index 1a9a107474..20168ba7a4 100644 --- a/tests/flytekit/unit/models/test_named_entity.py +++ b/tests/flytekit/unit/models/test_named_entity.py @@ -8,6 +8,7 @@ def test_identifier(): obj2 = named_entity.NamedEntityIdentifier.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 + def test_metadata(): obj = named_entity.NamedEntityMetadata("i am a description", named_entity.NamedEntityState.ACTIVE) obj2 = named_entity.NamedEntityMetadata.from_flyte_idl(obj.to_flyte_idl()) diff --git a/tests/flytekit/unit/models/test_qubole.py b/tests/flytekit/unit/models/test_qubole.py index cd006ecfd9..1421221fcf 100644 --- a/tests/flytekit/unit/models/test_qubole.py +++ b/tests/flytekit/unit/models/test_qubole.py @@ -1,20 +1,17 @@ from __future__ import absolute_import -import pytest - from flytekit.models import qubole -from tests.flytekit.common.parameterizers import LIST_OF_ALL_LITERAL_TYPES def test_hive_query(): - q = qubole.HiveQuery(query='some query', timeout_sec=10, retry_count=0) + q = qubole.HiveQuery(query="some query", timeout_sec=10, retry_count=0) q2 = qubole.HiveQuery.from_flyte_idl(q.to_flyte_idl()) assert q == q2 - assert q2.query == 'some query' + assert q2.query == "some query" def test_hive_job(): - query = qubole.HiveQuery(query='some query', timeout_sec=10, retry_count=0) - obj = qubole.QuboleHiveJob(query=query, cluster_label='default', tags=[]) + query = qubole.HiveQuery(query="some query", timeout_sec=10, retry_count=0) + obj = qubole.QuboleHiveJob(query=query, cluster_label="default", tags=[]) obj2 = qubole.QuboleHiveJob.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 diff --git a/tests/flytekit/unit/models/test_schedule.py b/tests/flytekit/unit/models/test_schedule.py index ff83eade1c..1889cd76c7 100644 --- a/tests/flytekit/unit/models/test_schedule.py +++ b/tests/flytekit/unit/models/test_schedule.py @@ -4,26 +4,26 @@ def test_schedule(): - obj = _schedule.Schedule(kickoff_time_input_arg='fdsa', cron_expression='1 2 3 4 5 6') + obj = _schedule.Schedule(kickoff_time_input_arg="fdsa", cron_expression="1 2 3 4 5 6") assert obj.rate is None - assert obj.cron_expression == '1 2 3 4 5 6' - assert obj.schedule_expression == '1 2 3 4 5 6' - assert obj.kickoff_time_input_arg == 'fdsa' + assert obj.cron_expression == "1 2 3 4 5 6" + assert obj.schedule_expression == "1 2 3 4 5 6" + assert obj.kickoff_time_input_arg == "fdsa" obj2 = _schedule.Schedule.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 assert obj2.rate is None - assert obj2.cron_expression == '1 2 3 4 5 6' - assert obj2.schedule_expression == '1 2 3 4 5 6' - assert obj2.kickoff_time_input_arg == 'fdsa' + assert obj2.cron_expression == "1 2 3 4 5 6" + assert obj2.schedule_expression == "1 2 3 4 5 6" + assert obj2.kickoff_time_input_arg == "fdsa" def test_schedule_fixed_rate(): fr = _schedule.Schedule.FixedRate(10, _schedule.Schedule.FixedRateUnit.MINUTE) - obj = _schedule.Schedule(kickoff_time_input_arg='fdsa', rate=fr) + obj = _schedule.Schedule(kickoff_time_input_arg="fdsa", rate=fr) assert obj.rate is not None assert obj.cron_expression is None - assert obj.kickoff_time_input_arg == 'fdsa' + assert obj.kickoff_time_input_arg == "fdsa" assert obj.rate == fr assert 
obj.schedule_expression == fr @@ -31,6 +31,6 @@ def test_schedule_fixed_rate(): assert obj == obj2 assert obj2.rate is not None assert obj2.cron_expression is None - assert obj2.kickoff_time_input_arg == 'fdsa' + assert obj2.kickoff_time_input_arg == "fdsa" assert obj2.rate == fr assert obj2.schedule_expression == fr diff --git a/tests/flytekit/unit/models/test_tasks.py b/tests/flytekit/unit/models/test_tasks.py index ac6694c2dd..31d4c84270 100644 --- a/tests/flytekit/unit/models/test_tasks.py +++ b/tests/flytekit/unit/models/test_tasks.py @@ -1,14 +1,15 @@ from __future__ import absolute_import -import pytest from datetime import timedelta -from google.protobuf import text_format from itertools import product +import pytest from flyteidl.core.tasks_pb2 import TaskMetadata -from flytekit.models import task, literals -from flytekit.models.core import identifier +from google.protobuf import text_format from k8s.io.api.core.v1 import generated_pb2 + +from flytekit.models import literals, task +from flytekit.models.core import identifier from tests.flytekit.common import parameterizers @@ -47,17 +48,18 @@ def test_task_metadata_interruptible_from_flyte_idl(): # Interruptible not set idl = TaskMetadata() obj = task.TaskMetadata.from_flyte_idl(idl) - assert obj.interruptible == None + assert obj.interruptible is None idl = TaskMetadata() idl.interruptible = True obj = task.TaskMetadata.from_flyte_idl(idl) - assert obj.interruptible == True + assert obj.interruptible is True idl = TaskMetadata() idl.interruptible = False obj = task.TaskMetadata.from_flyte_idl(idl) - assert obj.interruptible == False + assert obj.interruptible is False + def test_task_metadata(): obj = task.TaskMetadata( @@ -67,7 +69,7 @@ def test_task_metadata(): literals.RetryStrategy(3), True, "0.1.1b0", - "This is deprecated!" 
+ "This is deprecated!", ) assert obj.discoverable is True @@ -84,11 +86,7 @@ def test_task_metadata(): @pytest.mark.parametrize( "in_tuple", - product( - parameterizers.LIST_OF_TASK_METADATA, - parameterizers.LIST_OF_INTERFACES, - parameterizers.LIST_OF_RESOURCES - ) + product(parameterizers.LIST_OF_TASK_METADATA, parameterizers.LIST_OF_INTERFACES, parameterizers.LIST_OF_RESOURCES,), ) def test_task_template(in_tuple): task_metadata, interfaces, resources = in_tuple @@ -97,15 +95,10 @@ def test_task_template(in_tuple): "python", task_metadata, interfaces, - {'a': 1, 'b': {'c': 2, 'd': 3}}, + {"a": 1, "b": {"c": 2, "d": 3}}, container=task.Container( - "my_image", - ["this", "is", "a", "cmd"], - ["this", "is", "an", "arg"], - resources, - {'a': 'b'}, - {'d': 'e'} - ) + "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, + ), ) assert obj.id.resource_type == identifier.ResourceType.TASK assert obj.id.project == "project" @@ -115,18 +108,18 @@ def test_task_template(in_tuple): assert obj.type == "python" assert obj.metadata == task_metadata assert obj.interface == interfaces - assert obj.custom == {'a': 1, 'b': {'c': 2, 'd': 3}} + assert obj.custom == {"a": 1, "b": {"c": 2, "d": 3}} assert obj.container.image == "my_image" assert obj.container.resources == resources assert text_format.MessageToString(obj.to_flyte_idl()) == text_format.MessageToString( - task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).to_flyte_idl()) + task.TaskTemplate.from_flyte_idl(obj.to_flyte_idl()).to_flyte_idl() + ) @pytest.mark.parametrize("task_closure", parameterizers.LIST_OF_TASK_CLOSURES) def test_task(task_closure): obj = task.Task( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), - task_closure + identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), task_closure, ) assert obj.id.project == "project" assert obj.id.domain == "domain" @@ -139,19 +132,14 @@ def test_task(task_closure): @pytest.mark.parametrize("resources", parameterizers.LIST_OF_RESOURCES) def test_container(resources): obj = task.Container( - "my_image", - ["this", "is", "a", "cmd"], - ["this", "is", "an", "arg"], - resources, - {'a': 'b'}, - {'d': 'e'} + "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, ) obj.image == "my_image" obj.command == ["this", "is", "a", "cmd"] obj.args == ["this", "is", "an", "arg"] obj.resources == resources - obj.env == {'a': 'b'} - obj.config == {'d': 'e'} + obj.env == {"a": "b"} + obj.config == {"d": "e"} assert obj == task.Container.from_flyte_idl(obj.to_flyte_idl()) @@ -169,13 +157,19 @@ def test_sidecar_task(): def test_dataloadingconfig(): - dlc = task.DataLoadingConfig("s3://input/path", "s3://output/path", True, - task.DataLoadingConfig.LITERALMAP_FORMAT_YAML) + dlc = task.DataLoadingConfig( + "s3://input/path", "s3://output/path", True, task.DataLoadingConfig.LITERALMAP_FORMAT_YAML, + ) dlc2 = task.DataLoadingConfig.from_flyte_idl(dlc.to_flyte_idl()) assert dlc2 == dlc - dlc = task.DataLoadingConfig("s3://input/path", "s3://output/path", True, - task.DataLoadingConfig.LITERALMAP_FORMAT_YAML, io_strategy=task.IOStrategy()) + dlc = task.DataLoadingConfig( + "s3://input/path", + "s3://output/path", + True, + task.DataLoadingConfig.LITERALMAP_FORMAT_YAML, + io_strategy=task.IOStrategy(), + ) dlc2 = task.DataLoadingConfig.from_flyte_idl(dlc.to_flyte_idl()) assert dlc2 == dlc diff --git 
a/tests/flytekit/unit/models/test_types.py b/tests/flytekit/unit/models/test_types.py index ad410479e6..dd539db528 100644 --- a/tests/flytekit/unit/models/test_types.py +++ b/tests/flytekit/unit/models/test_types.py @@ -29,14 +29,16 @@ def test_schema_column(): def test_schema_type(): - obj = _types.SchemaType([ - _types.SchemaType.SchemaColumn("a", _types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - _types.SchemaType.SchemaColumn("b", _types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - _types.SchemaType.SchemaColumn("c", _types.SchemaType.SchemaColumn.SchemaColumnType.STRING), - _types.SchemaType.SchemaColumn("d", _types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - _types.SchemaType.SchemaColumn("e", _types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - _types.SchemaType.SchemaColumn("f", _types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN) - ]) + obj = _types.SchemaType( + [ + _types.SchemaType.SchemaColumn("a", _types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), + _types.SchemaType.SchemaColumn("b", _types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), + _types.SchemaType.SchemaColumn("c", _types.SchemaType.SchemaColumn.SchemaColumnType.STRING), + _types.SchemaType.SchemaColumn("d", _types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), + _types.SchemaType.SchemaColumn("e", _types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), + _types.SchemaType.SchemaColumn("f", _types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), + ] + ) assert obj.columns[0].name == "a" assert obj.columns[1].name == "b" @@ -63,14 +65,16 @@ def test_literal_types(): assert obj.map_value_type is None assert obj == _types.LiteralType.from_flyte_idl(obj.to_flyte_idl()) - schema_type = _types.SchemaType([ - _types.SchemaType.SchemaColumn("a", _types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), - _types.SchemaType.SchemaColumn("b", _types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), - _types.SchemaType.SchemaColumn("c", _types.SchemaType.SchemaColumn.SchemaColumnType.STRING), - _types.SchemaType.SchemaColumn("d", _types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), - _types.SchemaType.SchemaColumn("e", _types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), - _types.SchemaType.SchemaColumn("f", _types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN) - ]) + schema_type = _types.SchemaType( + [ + _types.SchemaType.SchemaColumn("a", _types.SchemaType.SchemaColumn.SchemaColumnType.INTEGER), + _types.SchemaType.SchemaColumn("b", _types.SchemaType.SchemaColumn.SchemaColumnType.FLOAT), + _types.SchemaType.SchemaColumn("c", _types.SchemaType.SchemaColumn.SchemaColumnType.STRING), + _types.SchemaType.SchemaColumn("d", _types.SchemaType.SchemaColumn.SchemaColumnType.DATETIME), + _types.SchemaType.SchemaColumn("e", _types.SchemaType.SchemaColumn.SchemaColumnType.DURATION), + _types.SchemaType.SchemaColumn("f", _types.SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN), + ] + ) obj = _types.LiteralType(schema=schema_type) assert obj.simple is None assert obj.schema == schema_type @@ -90,8 +94,8 @@ def test_literal_collections(literal_type): def test_output_reference(): - obj = _types.OutputReference(node_id='node1', var='var1') - assert obj.node_id == 'node1' - assert obj.var == 'var1' + obj = _types.OutputReference(node_id="node1", var="var1") + assert obj.node_id == "node1" + assert obj.var == "var1" obj2 = _types.OutputReference.from_flyte_idl(obj.to_flyte_idl()) assert obj == obj2 diff --git a/tests/flytekit/unit/models/test_workflow_closure.py 
b/tests/flytekit/unit/models/test_workflow_closure.py index a665268a37..0a2a6e09b2 100644 --- a/tests/flytekit/unit/models/test_workflow_closure.py +++ b/tests/flytekit/unit/models/test_workflow_closure.py @@ -2,32 +2,30 @@ from datetime import timedelta -from flytekit.models import workflow_closure as _workflow_closure, interface as _interface, \ - literals as _literals, types as _types, task as _task -from flytekit.models.core import workflow as _workflow, identifier as _identifier +from flytekit.models import interface as _interface +from flytekit.models import literals as _literals +from flytekit.models import task as _task +from flytekit.models import types as _types +from flytekit.models import workflow_closure as _workflow_closure +from flytekit.models.core import identifier as _identifier +from flytekit.models.core import workflow as _workflow def test_workflow_closure(): int_type = _types.LiteralType(_types.SimpleType.INTEGER) typed_interface = _interface.TypedInterface( - {'a': _interface.Variable(int_type, "description1")}, - { - 'b': _interface.Variable(int_type, "description2"), - 'c': _interface.Variable(int_type, "description3") - } + {"a": _interface.Variable(int_type, "description1")}, + {"b": _interface.Variable(int_type, "description2"), "c": _interface.Variable(int_type, "description3")}, ) - b0 = _literals.Binding('a', _literals.BindingData( - scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5)))) - b1 = _literals.Binding('b', _literals.BindingData( - promise=_types.OutputReference('my_node', 'b'))) - b2 = _literals.Binding('c', _literals.BindingData( - promise=_types.OutputReference('my_node', 'c'))) + b0 = _literals.Binding( + "a", _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5))), + ) + b1 = _literals.Binding("b", _literals.BindingData(promise=_types.OutputReference("my_node", "b"))) + b2 = _literals.Binding("c", _literals.BindingData(promise=_types.OutputReference("my_node", "c"))) node_metadata = _workflow.NodeMetadata( - name='node1', - timeout=timedelta(seconds=10), - retries=_literals.RetryStrategy(0) + name="node1", timeout=timedelta(seconds=10), retries=_literals.RetryStrategy(0) ) task_metadata = _task.TaskMetadata( @@ -37,7 +35,7 @@ def test_workflow_closure(): _literals.RetryStrategy(3), True, "0.1.1b0", - "This is deprecated!" 
+ "This is deprecated!", ) cpu_resource = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1") @@ -48,25 +46,16 @@ def test_workflow_closure(): "python", task_metadata, typed_interface, - {'a': 1, 'b': {'c': 2, 'd': 3}}, + {"a": 1, "b": {"c": 2, "d": 3}}, container=_task.Container( - "my_image", - ["this", "is", "a", "cmd"], - ["this", "is", "an", "arg"], - resources, - {}, - {} - ) + "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {}, {}, + ), ) task_node = _workflow.TaskNode(task.id) node = _workflow.Node( - id='my_node', - metadata=node_metadata, - inputs=[b0], - upstream_node_ids=[], - output_aliases=[], - task_node=task_node) + id="my_node", metadata=node_metadata, inputs=[b0], upstream_node_ids=[], output_aliases=[], task_node=task_node, + ) template = _workflow.WorkflowTemplate( id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version"), diff --git a/tests/flytekit/unit/sdk/conftest.py b/tests/flytekit/unit/sdk/conftest.py index 15d5e13e8c..0be301949f 100644 --- a/tests/flytekit/unit/sdk/conftest.py +++ b/tests/flytekit/unit/sdk/conftest.py @@ -1,9 +1,11 @@ from __future__ import absolute_import -from flytekit.configuration import TemporaryConfiguration + import pytest as _pytest +from flytekit.configuration import TemporaryConfiguration + -@_pytest.fixture(scope='function', autouse=True) +@_pytest.fixture(scope="function", autouse=True) def set_fake_config(): - with TemporaryConfiguration(None, internal_overrides={'image': 'fakeimage'}): + with TemporaryConfiguration(None, internal_overrides={"image": "fakeimage"}): yield diff --git a/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py b/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py index 9ab99645b8..be903a9c82 100644 --- a/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py @@ -1,32 +1,28 @@ from __future__ import absolute_import import mock -from flytekit.configuration.internal import IMAGE as _IMAGE - -from flytekit.common.tasks import sdk_dynamic as _sdk_dynamic, sdk_runnable as _sdk_runnable, sidecar_task as _sidecar_task +from k8s.io.api.core.v1 import generated_pb2 -from flytekit.sdk.tasks import inputs, outputs, dynamic_sidecar_task, python_task +from flytekit.common.tasks import sdk_dynamic as _sdk_dynamic +from flytekit.common.tasks import sdk_runnable as _sdk_runnable +from flytekit.common.tasks import sidecar_task as _sidecar_task +from flytekit.configuration.internal import IMAGE as _IMAGE +from flytekit.sdk.tasks import dynamic_sidecar_task, inputs, outputs, python_task from flytekit.sdk.types import Types -from k8s.io.api.core.v1 import generated_pb2 - def get_pod_spec(): a_container = generated_pb2.Container(name="main") a_container.command.extend(["foo", "bar"]) - a_container.volumeMounts.extend([generated_pb2.VolumeMount( - name="scratch", - mountPath="/scratch", - )]) + a_container.volumeMounts.extend([generated_pb2.VolumeMount(name="scratch", mountPath="/scratch",)]) - pod_spec = generated_pb2.PodSpec( - restartPolicy="Never", - ) + pod_spec = generated_pb2.PodSpec(restartPolicy="Never",) pod_spec.containers.extend([a_container, generated_pb2.Container(name="sidecar")]) return pod_spec -with mock.patch.object(_IMAGE, 'get', return_value='docker.io/blah:abc123'): +with mock.patch.object(_IMAGE, "get", return_value="docker.io/blah:abc123"): + @outputs(out1=Types.String) @python_task def simple_python_task(wf_params, out1): @@ -35,8 
+31,8 @@ def simple_python_task(wf_params, out1):
 @inputs(in1=Types.Integer)
 @outputs(out1=Types.String)
 @dynamic_sidecar_task(
-    cpu_request='10',
-    memory_limit='2Gi',
+    cpu_request="10",
+    memory_limit="2Gi",
     environment={"foo": "bar"},
     pod_spec=get_pod_spec(),
     primary_container_name="main",
@@ -51,18 +47,27 @@ def test_dynamic_sidecar_task():
     assert isinstance(simple_dynamic_sidecar_task, _sidecar_task.SdkSidecarTask)
     assert isinstance(simple_dynamic_sidecar_task, _sdk_dynamic.SdkDynamicTaskMixin)
-    pod_spec = simple_dynamic_sidecar_task.custom['podSpec']
-    assert pod_spec['restartPolicy'] == 'Never'
-    assert len(pod_spec['containers']) == 2
-    primary_container = pod_spec['containers'][0]
-    assert primary_container['name'] == 'main'
-    assert primary_container['args'] == ['pyflyte-execute', '--task-module',
-                                         'tests.flytekit.unit.sdk.tasks.test_dynamic_sidecar_tasks', '--task-name',
-                                         'simple_dynamic_sidecar_task', '--inputs', '{{.input}}', '--output-prefix',
-                                         '{{.outputPrefix}}']
-    assert primary_container['volumeMounts'] == [{'mountPath': '/scratch', 'name': 'scratch'}]
-    assert {'name': 'foo', 'value': 'bar'} in primary_container['env']
-    assert primary_container['resources'] == {'requests': {'cpu': {'string': '10'}},
-                                              'limits': {'memory': {'string': '2Gi'}}}
-    assert pod_spec['containers'][1]['name'] == 'sidecar'
-    assert simple_dynamic_sidecar_task.custom['primaryContainerName'] == 'main'
+    pod_spec = simple_dynamic_sidecar_task.custom["podSpec"]
+    assert pod_spec["restartPolicy"] == "Never"
+    assert len(pod_spec["containers"]) == 2
+    primary_container = pod_spec["containers"][0]
+    assert primary_container["name"] == "main"
+    assert primary_container["args"] == [
+        "pyflyte-execute",
+        "--task-module",
+        "tests.flytekit.unit.sdk.tasks.test_dynamic_sidecar_tasks",
+        "--task-name",
+        "simple_dynamic_sidecar_task",
+        "--inputs",
+        "{{.input}}",
+        "--output-prefix",
+        "{{.outputPrefix}}",
+    ]
+    assert primary_container["volumeMounts"] == [{"mountPath": "/scratch", "name": "scratch"}]
+    assert {"name": "foo", "value": "bar"} in primary_container["env"]
+    assert primary_container["resources"] == {
+        "requests": {"cpu": {"string": "10"}},
+        "limits": {"memory": {"string": "2Gi"}},
+    }
+    assert pod_spec["containers"][1]["name"] == "sidecar"
+    assert simple_dynamic_sidecar_task.custom["primaryContainerName"] == "main"
diff --git a/tests/flytekit/unit/sdk/tasks/test_dynamic_tasks.py b/tests/flytekit/unit/sdk/tasks/test_dynamic_tasks.py
index b947fd40fe..4fcabc1877 100644
--- a/tests/flytekit/unit/sdk/tasks/test_dynamic_tasks.py
+++ b/tests/flytekit/unit/sdk/tasks/test_dynamic_tasks.py
@@ -1,12 +1,12 @@
-from __future__ import absolute_import
-from __future__ import print_function
+from __future__ import absolute_import, print_function
 
 from six import moves as _six_moves
 
-from flytekit.common.tasks import sdk_runnable as _sdk_runnable, sdk_dynamic as _sdk_dynamic
-from flytekit.sdk.tasks import inputs, outputs, dynamic_task, python_task
+from flytekit.common.tasks import sdk_dynamic as _sdk_dynamic
+from flytekit.common.tasks import sdk_runnable as _sdk_runnable
+from flytekit.sdk.tasks import dynamic_task, inputs, outputs, python_task
 from flytekit.sdk.types import Types
-from flytekit.sdk.workflow import Input, workflow, Output
+from flytekit.sdk.workflow import Input, Output, workflow
 
 
 @inputs(in1=Types.Integer)
@@ -124,20 +124,13 @@ def dynamic_wf_task(wf_params, task_input_num, out):
     node1 = sq_sub_task(in1=input_a)
 
     MyUnregisteredWorkflow = workflow(
-        inputs={
-            'a': input_a,
-        },
-        outputs={
-            'ooo': Output(node1.outputs.out1, sdk_type=Types.Integer,
-                          help='This is an integer output')
-        },
-        nodes={
-            'node_one': node1,
-        }
+        inputs={"a": input_a},
+        outputs={"ooo": Output(node1.outputs.out1, sdk_type=Types.Integer, help="This is an integer output")},
+        nodes={"node_one": node1},
     )
 
-    setattr(MyUnregisteredWorkflow, 'auto_assign_name', manual_assign_name)
-    MyUnregisteredWorkflow._platform_valid_name = 'unregistered'
+    setattr(MyUnregisteredWorkflow, "auto_assign_name", manual_assign_name)
+    MyUnregisteredWorkflow._platform_valid_name = "unregistered"
 
     unregistered_workflow_execution = MyUnregisteredWorkflow(a=task_input_num)
     out.set(unregistered_workflow_execution.outputs.ooo)
@@ -149,10 +142,17 @@ def test_batch_task():
     assert isinstance(sample_batch_task, _sdk_dynamic.SdkDynamicTaskMixin)
 
     expected = {
-        'out_str': ["I'm the first result", 'hello 0', "I'm after each sub-task result", 'hello 1',
-                    "I'm after each sub-task result", 'hello 2', "I'm after each sub-task result",
-                    "I'm the last result"],
-        'out_ints': [[0, 0, 0], [1, 2, 3], [2, 4, 6], [0, 1, 4], [0, 1, 4]]
+        "out_str": [
+            "I'm the first result",
+            "hello 0",
+            "I'm after each sub-task result",
+            "hello 1",
+            "I'm after each sub-task result",
+            "hello 2",
+            "I'm after each sub-task result",
+            "I'm the last result",
+        ],
+        "out_ints": [[0, 0, 0], [1, 2, 3], [2, 4, 6], [0, 1, 4], [0, 1, 4]],
     }
 
     res = sample_batch_task.unit_test(in1=3)
@@ -160,9 +160,7 @@
 
 
 def test_no_future_batch_task():
-    expected = {
-        'out_str': ["res1", "res2"]
-    }
+    expected = {"out_str": ["res1", "res2"]}
 
     res = no_future_batch_task.unit_test(in1=3)
     assert expected == res
@@ -187,40 +185,26 @@ def nested_dynamic_wf_task(wf_params, task_input_num, out):
     node1 = sq_sub_task(in1=input_a)
 
     MyUnregisteredWorkflowInner = workflow(
-        inputs={
-            'a': input_a,
-        },
-        outputs={
-            'ooo': Output(node1.outputs.out1, sdk_type=Types.Integer,
-                          help='This is an integer output')
-        },
-        nodes={
-            'node_one': node1,
-        }
+        inputs={"a": input_a},
+        outputs={"ooo": Output(node1.outputs.out1, sdk_type=Types.Integer, help="This is an integer output")},
+        nodes={"node_one": node1},
    )
 
-    setattr(MyUnregisteredWorkflowInner, 'auto_assign_name', manual_assign_name)
-    MyUnregisteredWorkflowInner._platform_valid_name = 'unregistered'
+    setattr(MyUnregisteredWorkflowInner, "auto_assign_name", manual_assign_name)
+    MyUnregisteredWorkflowInner._platform_valid_name = "unregistered"
 
     # Output workflow
     input_a = Input(Types.Integer, help="Tell me something")
     node1 = MyUnregisteredWorkflowInner(a=task_input_num)
 
     MyUnregisteredWorkflowOuter = workflow(
-        inputs={
-            'a': input_a,
-        },
-        outputs={
-            'ooo': Output(node1.outputs.ooo, sdk_type=Types.Integer,
-                          help='This is an integer output')
-        },
-        nodes={
-            'node_one': node1,
-        }
+        inputs={"a": input_a},
+        outputs={"ooo": Output(node1.outputs.ooo, sdk_type=Types.Integer, help="This is an integer output")},
+        nodes={"node_one": node1},
     )
 
-    setattr(MyUnregisteredWorkflowOuter, 'auto_assign_name', manual_assign_name)
-    MyUnregisteredWorkflowOuter._platform_valid_name = 'unregistered'
+    setattr(MyUnregisteredWorkflowOuter, "auto_assign_name", manual_assign_name)
+    MyUnregisteredWorkflowOuter._platform_valid_name = "unregistered"
 
     unregistered_workflow_execution = MyUnregisteredWorkflowOuter(a=task_input_num)
     out.set(unregistered_workflow_execution.outputs.ooo)
@@ -242,18 +226,10 @@ def dynamic_wf_no_outputs_task(wf_params, task_input_num):
     input_a = Input(Types.Integer, help="Tell me something")
     node1 = sq_sub_task(in1=input_a)
 
-    MyUnregisteredWorkflow = workflow(
-        inputs={
-            'a': input_a,
-        },
-        outputs={},
-        nodes={
-            'node_one': node1,
-        }
-    )
+    MyUnregisteredWorkflow = workflow(inputs={"a": input_a}, outputs={}, nodes={"node_one": node1})
 
-    setattr(MyUnregisteredWorkflow, 'auto_assign_name', manual_assign_name)
-    MyUnregisteredWorkflow._platform_valid_name = 'unregistered'
+    setattr(MyUnregisteredWorkflow, "auto_assign_name", manual_assign_name)
+    MyUnregisteredWorkflow._platform_valid_name = "unregistered"
 
     unregistered_workflow_execution = MyUnregisteredWorkflow(a=task_input_num)
     yield unregistered_workflow_execution
diff --git a/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py b/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py
index 9b3ded81e2..48475ca18e 100644
--- a/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py
+++ b/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py
@@ -1,32 +1,35 @@
-from __future__ import absolute_import
-from __future__ import print_function
+from __future__ import absolute_import, print_function
 
-from datetime import datetime as _datetime
 import logging as _logging
+from datetime import datetime as _datetime
+
 import six as _six
 
 from flytekit.common import utils as _common_utils
+from flytekit.common.tasks import hive_task as _hive_task
 from flytekit.common.tasks import output as _task_output
-from flytekit.common.tasks import sdk_runnable as _sdk_runnable, hive_task as _hive_task
-from flytekit.common.types import base_sdk_types as _base_sdk_types, containers as _containers, schema as _schema
-from flytekit.models import literals as _literals
-from flytekit.models.core.identifier import WorkflowExecutionIdentifier
+from flytekit.common.tasks import sdk_runnable as _sdk_runnable
+from flytekit.common.types import base_sdk_types as _base_sdk_types
+from flytekit.common.types import containers as _containers
 from flytekit.common.types import helpers as _type_helpers
+from flytekit.common.types import schema as _schema
 from flytekit.common.types.impl.schema import Schema
 from flytekit.engines import common as _common_engine
-from flytekit.sdk.tasks import inputs, hive_task, qubole_hive_task, outputs
+from flytekit.models import literals as _literals
+from flytekit.models.core.identifier import WorkflowExecutionIdentifier
+from flytekit.sdk.tasks import hive_task, inputs, outputs, qubole_hive_task
 from flytekit.sdk.types import Types
 
 
-@hive_task(cache_version='1')
+@hive_task(cache_version="1")
 def sample_hive_task_no_input(wf_params):
-    return _six.text_type('select 5')
+    return _six.text_type("select 5")
 
 
 @inputs(in1=Types.Integer)
-@hive_task(cache_version='1')
+@hive_task(cache_version="1")
 def sample_hive_task(wf_params, in1):
-    return _six.text_type('select ') + _six.text_type(in1)
+    return _six.text_type("select ") + _six.text_type(in1)
 
 
 @hive_task
@@ -34,21 +37,19 @@ def sample_hive_task_no_queries(wf_params):
     return []
 
 
-@qubole_hive_task(cache_version='1',
-                  cluster_label=_six.text_type('cluster_label'),
-                  tags=[],
-                  )
+@qubole_hive_task(
+    cache_version="1", cluster_label=_six.text_type("cluster_label"), tags=[],
+)
 def sample_qubole_hive_task_no_input(wf_params):
-    return _six.text_type('select 5')
+    return _six.text_type("select 5")
 
 
 @inputs(in1=Types.Integer)
-@qubole_hive_task(cache_version='1',
-                  cluster_label=_six.text_type('cluster_label'),
-                  tags=[_six.text_type('tag1')],
-                  )
+@qubole_hive_task(
+    cache_version="1", cluster_label=_six.text_type("cluster_label"), tags=[_six.text_type("tag1")],
+)
 def sample_qubole_hive_task(wf_params, in1):
-    return _six.text_type('select ') + _six.text_type(in1)
+    return _six.text_type("select ") + _six.text_type(in1)
 
 
 def test_hive_task():
@@ -72,9 +73,9 @@ def two_queries(wf_params, hive_results):
 
 def test_interface_setup():
     outs = two_queries.interface.outputs
-    assert outs['hive_results'].type.collection_type is not None
-    assert outs['hive_results'].type.collection_type.schema is not None
-    assert outs['hive_results'].type.collection_type.schema.columns == []
+    assert outs["hive_results"].type.collection_type is not None
+    assert outs["hive_results"].type.collection_type.schema is not None
+    assert outs["hive_results"].type.collection_type.schema.columns == []
 
 
 def test_sdk_output_references_construction():
@@ -83,25 +84,21 @@
     references = {
        name: _task_output.OutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type))
        for name, variable in _six.iteritems(two_queries.interface.outputs)
    }
     # Before user code is run, the outputs passed to the user code should not have values
-    assert references['hive_results'].sdk_value == _base_sdk_types.Void()
-    # Should be a list of schemas
-    assert isinstance(references['hive_results'].sdk_type, _containers.TypedCollectionType)
-    assert isinstance(references['hive_results'].sdk_type.sub_type, _schema.SchemaInstantiator)
+    assert references["hive_results"].sdk_value == _base_sdk_types.Void()
+    # Should be a list of schemas
+    assert isinstance(references["hive_results"].sdk_type, _containers.TypedCollectionType)
+    assert isinstance(references["hive_results"].sdk_type.sub_type, _schema.SchemaInstantiator)
 
 
 def test_hive_task_query_generation():
     with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory:
         context = _common_engine.EngineContext(
-            execution_id=WorkflowExecutionIdentifier(
-                project='unit_test',
-                domain='unit_test',
-                name='unit_test'
-            ),
+            execution_id=WorkflowExecutionIdentifier(project="unit_test", domain="unit_test", name="unit_test"),
             execution_date=_datetime.utcnow(),
             stats=None,  # TODO: A mock stats object that we can read later.
             logging=_logging,  # TODO: A mock logging object that we can read later.
-            tmp_dir=user_working_directory
+            tmp_dir=user_working_directory,
         )
 
         references = {
            name: _task_output.OutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type))
            for name, variable in _six.iteritems(two_queries.interface.outputs)
@@ -116,24 +113,20 @@
         assert len(qubole_hive_jobs[1].query_collection.queries) == 1
 
         # The output references should now have the same fake S3 path as the formatted queries
-        assert references['hive_results'].value[0].uri != ''
-        assert references['hive_results'].value[1].uri != ''
-        assert references['hive_results'].value[0].uri in qubole_hive_jobs[0].query.query
-        assert references['hive_results'].value[1].uri in qubole_hive_jobs[1].query.query
+        assert references["hive_results"].value[0].uri != ""
+        assert references["hive_results"].value[1].uri != ""
+        assert references["hive_results"].value[0].uri in qubole_hive_jobs[0].query.query
+        assert references["hive_results"].value[1].uri in qubole_hive_jobs[1].query.query
 
 
 def test_hive_task_dynamic_job_spec_generation():
     with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory:
         context = _common_engine.EngineContext(
-            execution_id=WorkflowExecutionIdentifier(
-                project='unit_test',
-                domain='unit_test',
-                name='unit_test'
-            ),
+            execution_id=WorkflowExecutionIdentifier(project="unit_test", domain="unit_test", name="unit_test"),
             execution_date=_datetime.utcnow(),
             stats=None,  # TODO: A mock stats object that we can read later.
             logging=_logging,  # TODO: A mock logging object that we can read later.
- assert simple_pytorch_task._get_container_definition().args[0] == 'pyflyte-execute' + assert simple_pytorch_task._get_container_definition().args[0] == "pyflyte-execute" pb2 = simple_pytorch_task.to_flyte_idl() - assert pb2.custom['workers'] == 1 \ No newline at end of file + assert pb2.custom["workers"] == 1 diff --git a/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py b/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py index ae9bb7dfff..624b09ca4f 100644 --- a/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py @@ -1,20 +1,30 @@ from __future__ import absolute_import -from flytekit.common.tasks.sagemaker.training_job_task import SdkBuiltinAlgorithmTrainingJobTask -from flytekit.common.tasks.sagemaker.hpo_job_task import SdkSimpleHyperparameterTuningJobTask -from flytekit.common import constants as _common_constants -from flytekit.common.tasks import task as _sdk_task -from flytekit.models.core import identifier as _identifier + import datetime as _datetime -from flytekit.models.sagemaker.training_job import TrainingJobResourceConfig, AlgorithmSpecification, \ - MetricDefinition, AlgorithmName, InputMode, InputContentType + +from flyteidl.plugins.sagemaker.hyperparameter_tuning_job_pb2 import HyperparameterTuningJobConfig as _pb2_HPOJobConfig +from flyteidl.plugins.sagemaker.training_job_pb2 import TrainingJobResourceConfig as _pb2_TrainingJobResourceConfig + # from flytekit.sdk.sagemaker.types import InputMode, AlgorithmName from google.protobuf.json_format import ParseDict -from flyteidl.plugins.sagemaker.training_job_pb2 import TrainingJobResourceConfig as _pb2_TrainingJobResourceConfig -from flyteidl.plugins.sagemaker.hyperparameter_tuning_job_pb2 import HyperparameterTuningJobConfig as _pb2_HPOJobConfig -from flytekit.sdk import types as _sdk_types + +from flytekit.common import constants as _common_constants +from flytekit.common.tasks import task as _sdk_task from flytekit.common.tasks.sagemaker import hpo_job_task +from flytekit.common.tasks.sagemaker.hpo_job_task import SdkSimpleHyperparameterTuningJobTask +from flytekit.common.tasks.sagemaker.training_job_task import SdkBuiltinAlgorithmTrainingJobTask from flytekit.models import types as _idl_types +from flytekit.models.core import identifier as _identifier from flytekit.models.core import types as _core_types +from flytekit.models.sagemaker.training_job import ( + AlgorithmName, + AlgorithmSpecification, + InputContentType, + InputMode, + MetricDefinition, + TrainingJobResourceConfig, +) +from flytekit.sdk import types as _sdk_types example_hyperparams = { "base_score": "0.5", @@ -44,9 +54,7 @@ builtin_algorithm_training_job_task = SdkBuiltinAlgorithmTrainingJobTask( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", - instance_count=1, - volume_size_in_gb=25, + instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, @@ -57,62 +65,62 @@ ) builtin_algorithm_training_job_task._id = _identifier.Identifier( - _identifier.ResourceType.TASK, "my_project", "my_domain", "my_name", "my_version") + _identifier.ResourceType.TASK, "my_project", "my_domain", "my_name", "my_version" +) def test_builtin_algorithm_training_job_task(): assert isinstance(builtin_algorithm_training_job_task, SdkBuiltinAlgorithmTrainingJobTask) assert isinstance(builtin_algorithm_training_job_task, _sdk_task.SdkTask) - assert 
builtin_algorithm_training_job_task.interface.inputs['train'].description == '' - assert builtin_algorithm_training_job_task.interface.inputs['train'].type == \ - _idl_types.LiteralType( - blob=_core_types.BlobType( - format="csv", - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART - ) - ) - assert builtin_algorithm_training_job_task.interface.inputs['train'].type == \ - _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() - assert builtin_algorithm_training_job_task.interface.inputs['validation'].description == '' - assert builtin_algorithm_training_job_task.interface.inputs['validation'].type == \ - _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() - assert builtin_algorithm_training_job_task.interface.inputs['train'].type == \ - _idl_types.LiteralType( - blob=_core_types.BlobType( - format="csv", - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART - ) - ) - assert builtin_algorithm_training_job_task.interface.inputs['static_hyperparameters'].description == '' - assert builtin_algorithm_training_job_task.interface.inputs['static_hyperparameters'].type == \ - _sdk_types.Types.Generic.to_flyte_literal_type() - assert builtin_algorithm_training_job_task.interface.outputs['model'].description == '' - assert builtin_algorithm_training_job_task.interface.outputs['model'].type == \ - _sdk_types.Types.Blob.to_flyte_literal_type() + assert builtin_algorithm_training_job_task.interface.inputs["train"].description == "" + assert builtin_algorithm_training_job_task.interface.inputs["train"].type == _idl_types.LiteralType( + blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + ) + assert ( + builtin_algorithm_training_job_task.interface.inputs["train"].type + == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() + ) + assert builtin_algorithm_training_job_task.interface.inputs["validation"].description == "" + assert ( + builtin_algorithm_training_job_task.interface.inputs["validation"].type + == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() + ) + assert builtin_algorithm_training_job_task.interface.inputs["train"].type == _idl_types.LiteralType( + blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + ) + assert builtin_algorithm_training_job_task.interface.inputs["static_hyperparameters"].description == "" + assert ( + builtin_algorithm_training_job_task.interface.inputs["static_hyperparameters"].type + == _sdk_types.Types.Generic.to_flyte_literal_type() + ) + assert builtin_algorithm_training_job_task.interface.outputs["model"].description == "" + assert ( + builtin_algorithm_training_job_task.interface.outputs["model"].type + == _sdk_types.Types.Blob.to_flyte_literal_type() + ) assert builtin_algorithm_training_job_task.type == _common_constants.SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK assert builtin_algorithm_training_job_task.metadata.timeout == _datetime.timedelta(seconds=0) - assert builtin_algorithm_training_job_task.metadata.deprecated_error_message == '' + assert builtin_algorithm_training_job_task.metadata.deprecated_error_message == "" assert builtin_algorithm_training_job_task.metadata.discoverable is False - assert builtin_algorithm_training_job_task.metadata.discovery_version == '' + assert builtin_algorithm_training_job_task.metadata.discovery_version == "" assert builtin_algorithm_training_job_task.metadata.retries.retries == 0 assert "metricDefinitions" not in 
builtin_algorithm_training_job_task.custom["algorithmSpecification"].keys() - ParseDict(builtin_algorithm_training_job_task.custom['trainingJobResourceConfig'], - _pb2_TrainingJobResourceConfig) # fails the test if it cannot be parsed + ParseDict( + builtin_algorithm_training_job_task.custom["trainingJobResourceConfig"], _pb2_TrainingJobResourceConfig, + ) # fails the test if it cannot be parsed builtin_algorithm_training_job_task2 = SdkBuiltinAlgorithmTrainingJobTask( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", - instance_count=1, - volume_size_in_gb=25, + instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, input_content_type=InputContentType.TEXT_CSV, algorithm_name=AlgorithmName.XGBOOST, algorithm_version="0.72", - metric_definitions=[MetricDefinition(name="Validation error", regex="validation:error")] + metric_definitions=[MetricDefinition(name="Validation error", regex="validation:error")], ), ) @@ -120,66 +128,67 @@ def test_builtin_algorithm_training_job_task(): training_job=builtin_algorithm_training_job_task2, max_number_of_training_jobs=10, max_parallel_training_jobs=5, - cache_version='1', + cache_version="1", retries=2, cacheable=True, ) simple_xgboost_hpo_job_task._id = _identifier.Identifier( - _identifier.ResourceType.TASK, "my_project", "my_domain", "my_name", "my_version") + _identifier.ResourceType.TASK, "my_project", "my_domain", "my_name", "my_version" +) def test_simple_hpo_job_task(): assert isinstance(simple_xgboost_hpo_job_task, SdkSimpleHyperparameterTuningJobTask) assert isinstance(simple_xgboost_hpo_job_task, _sdk_task.SdkTask) # Checking if the input of the underlying SdkTrainingJobTask has been embedded - assert simple_xgboost_hpo_job_task.interface.inputs['train'].description == '' - assert simple_xgboost_hpo_job_task.interface.inputs['train'].type == \ - _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() - assert simple_xgboost_hpo_job_task.interface.inputs['train'].type == \ - _idl_types.LiteralType( - blob=_core_types.BlobType( - format="csv", - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART - ) - ) - assert simple_xgboost_hpo_job_task.interface.inputs['validation'].description == '' - assert simple_xgboost_hpo_job_task.interface.inputs['validation'].type == \ - _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() - assert simple_xgboost_hpo_job_task.interface.inputs['validation'].type == \ - _idl_types.LiteralType( - blob=_core_types.BlobType( - format="csv", - dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART - ) - ) - assert simple_xgboost_hpo_job_task.interface.inputs['static_hyperparameters'].description == '' - assert simple_xgboost_hpo_job_task.interface.inputs['static_hyperparameters'].type == \ - _sdk_types.Types.Generic.to_flyte_literal_type() + assert simple_xgboost_hpo_job_task.interface.inputs["train"].description == "" + assert ( + simple_xgboost_hpo_job_task.interface.inputs["train"].type + == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() + ) + assert simple_xgboost_hpo_job_task.interface.inputs["train"].type == _idl_types.LiteralType( + blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + ) + assert simple_xgboost_hpo_job_task.interface.inputs["validation"].description == "" + assert ( + simple_xgboost_hpo_job_task.interface.inputs["validation"].type + == 
_sdk_types.Types.MultiPartCSV.to_flyte_literal_type() + ) + assert simple_xgboost_hpo_job_task.interface.inputs["validation"].type == _idl_types.LiteralType( + blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + ) + assert simple_xgboost_hpo_job_task.interface.inputs["static_hyperparameters"].description == "" + assert ( + simple_xgboost_hpo_job_task.interface.inputs["static_hyperparameters"].type + == _sdk_types.Types.Generic.to_flyte_literal_type() + ) # Checking if the hpo-specific input is defined - assert simple_xgboost_hpo_job_task.interface.inputs['hyperparameter_tuning_job_config'].description == '' - assert simple_xgboost_hpo_job_task.interface.inputs['hyperparameter_tuning_job_config'].type == \ - _sdk_types.Types.Proto(_pb2_HPOJobConfig).to_flyte_literal_type() - assert simple_xgboost_hpo_job_task.interface.outputs['model'].description == '' - assert simple_xgboost_hpo_job_task.interface.outputs['model'].type == \ - _sdk_types.Types.Blob.to_flyte_literal_type() + assert simple_xgboost_hpo_job_task.interface.inputs["hyperparameter_tuning_job_config"].description == "" + assert ( + simple_xgboost_hpo_job_task.interface.inputs["hyperparameter_tuning_job_config"].type + == _sdk_types.Types.Proto(_pb2_HPOJobConfig).to_flyte_literal_type() + ) + assert simple_xgboost_hpo_job_task.interface.outputs["model"].description == "" + assert simple_xgboost_hpo_job_task.interface.outputs["model"].type == _sdk_types.Types.Blob.to_flyte_literal_type() assert simple_xgboost_hpo_job_task.type == _common_constants.SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK # Checking if the spec of the TrainingJob is embedded into the custom field # of this SdkSimpleHyperparameterTuningJobTask assert simple_xgboost_hpo_job_task.to_flyte_idl().custom["trainingJob"] == ( - builtin_algorithm_training_job_task2.to_flyte_idl().custom) + builtin_algorithm_training_job_task2.to_flyte_idl().custom + ) assert simple_xgboost_hpo_job_task.metadata.timeout == _datetime.timedelta(seconds=0) assert simple_xgboost_hpo_job_task.metadata.discoverable is True - assert simple_xgboost_hpo_job_task.metadata.discovery_version == '1' + assert simple_xgboost_hpo_job_task.metadata.discovery_version == "1" assert simple_xgboost_hpo_job_task.metadata.retries.retries == 2 - assert simple_xgboost_hpo_job_task.metadata.deprecated_error_message == '' + assert simple_xgboost_hpo_job_task.metadata.deprecated_error_message == "" assert "metricDefinitions" in simple_xgboost_hpo_job_task.custom["trainingJob"]["algorithmSpecification"].keys() assert len(simple_xgboost_hpo_job_task.custom["trainingJob"]["algorithmSpecification"]["metricDefinitions"]) == 1 - """ These are attributes for SdkRunnable. We will need these when supporting CustomTrainingJobTask and CustomHPOJobTask + """ These are attributes for SdkRunnable. 
We will need these when supporting CustomTrainingJobTask and CustomHPOJobTask assert simple_xgboost_hpo_job_task.task_module == __name__ assert simple_xgboost_hpo_job_task._get_container_definition().args[0] == 'pyflyte-execute' """ diff --git a/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py b/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py index 65d9b4ea28..72989eeffc 100644 --- a/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py @@ -1,40 +1,33 @@ from __future__ import absolute_import import mock -from flytekit.configuration.internal import IMAGE as _IMAGE +from k8s.io.api.core.v1 import generated_pb2 -from flytekit.common.tasks import task as _sdk_task, sidecar_task as _sidecar_task +from flytekit.common.tasks import sidecar_task as _sidecar_task +from flytekit.common.tasks import task as _sdk_task +from flytekit.configuration.internal import IMAGE as _IMAGE from flytekit.models.core import identifier as _identifier - from flytekit.sdk.tasks import inputs, outputs, sidecar_task from flytekit.sdk.types import Types -from k8s.io.api.core.v1 import generated_pb2 - def get_pod_spec(): - a_container = generated_pb2.Container( - name="a container", - ) + a_container = generated_pb2.Container(name="a container",) a_container.command.extend(["fee", "fi", "fo", "fum"]) - a_container.volumeMounts.extend([generated_pb2.VolumeMount( - name="volume mount", - mountPath="some/where", - )]) + a_container.volumeMounts.extend([generated_pb2.VolumeMount(name="volume mount", mountPath="some/where",)]) - pod_spec = generated_pb2.PodSpec( - restartPolicy="OnFailure", - ) + pod_spec = generated_pb2.PodSpec(restartPolicy="OnFailure",) pod_spec.containers.extend([a_container, generated_pb2.Container(name="another container")]) return pod_spec -with mock.patch.object(_IMAGE, 'get', return_value='docker.io/blah:abc123'): +with mock.patch.object(_IMAGE, "get", return_value="docker.io/blah:abc123"): + @inputs(in1=Types.Integer) @outputs(out1=Types.String) @sidecar_task( - cpu_request='10', - gpu_limit='2', + cpu_request="10", + gpu_limit="2", environment={"foo": "bar"}, pod_spec=get_pod_spec(), primary_container_name="a container", @@ -50,18 +43,27 @@ def test_sidecar_task(): assert isinstance(simple_sidecar_task, _sdk_task.SdkTask) assert isinstance(simple_sidecar_task, _sidecar_task.SdkSidecarTask) - pod_spec = simple_sidecar_task.custom['podSpec'] - assert pod_spec['restartPolicy'] == 'OnFailure' - assert len(pod_spec['containers']) == 2 - primary_container = pod_spec['containers'][0] - assert primary_container['name'] == 'a container' - assert primary_container['args'] == ['pyflyte-execute', '--task-module', - 'tests.flytekit.unit.sdk.tasks.test_sidecar_tasks', '--task-name', - 'simple_sidecar_task', '--inputs', '{{.input}}', '--output-prefix', - '{{.outputPrefix}}'] - assert primary_container['volumeMounts'] == [{'mountPath': 'some/where', 'name': 'volume mount'}] - assert {'name': 'foo', 'value': 'bar'} in primary_container['env'] - assert primary_container['resources'] == {'requests': {'cpu': {'string': '10'}}, - 'limits': {'gpu': {'string': '2'}}} - assert pod_spec['containers'][1]['name'] == 'another container' - assert simple_sidecar_task.custom['primaryContainerName'] == 'a container' + pod_spec = simple_sidecar_task.custom["podSpec"] + assert pod_spec["restartPolicy"] == "OnFailure" + assert len(pod_spec["containers"]) == 2 + primary_container = pod_spec["containers"][0] + assert primary_container["name"] == "a container" + assert 
primary_container["args"] == [ + "pyflyte-execute", + "--task-module", + "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", + "--task-name", + "simple_sidecar_task", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + ] + assert primary_container["volumeMounts"] == [{"mountPath": "some/where", "name": "volume mount"}] + assert {"name": "foo", "value": "bar"} in primary_container["env"] + assert primary_container["resources"] == { + "requests": {"cpu": {"string": "10"}}, + "limits": {"gpu": {"string": "2"}}, + } + assert pod_spec["containers"][1]["name"] == "another container" + assert simple_sidecar_task.custom["primaryContainerName"] == "a container" diff --git a/tests/flytekit/unit/sdk/tasks/test_spark_task.py b/tests/flytekit/unit/sdk/tasks/test_spark_task.py index e0bcc6e7b3..31418a6b46 100644 --- a/tests/flytekit/unit/sdk/tasks/test_spark_task.py +++ b/tests/flytekit/unit/sdk/tasks/test_spark_task.py @@ -1,19 +1,22 @@ from __future__ import absolute_import + +import datetime as _datetime +import os as _os +import sys as _sys + from flytekit.bin import entrypoint as _entrypoint -from flytekit.sdk.tasks import spark_task, inputs, outputs -from flytekit.sdk.types import Types from flytekit.common import constants as _common_constants -from flytekit.common.tasks import sdk_runnable as _sdk_runnable, spark_task as _spark_task +from flytekit.common.tasks import sdk_runnable as _sdk_runnable +from flytekit.common.tasks import spark_task as _spark_task from flytekit.models import types as _type_models from flytekit.models.core import identifier as _identifier -import datetime as _datetime -import os as _os -import sys as _sys +from flytekit.sdk.tasks import inputs, outputs, spark_task +from flytekit.sdk.types import Types @inputs(in1=Types.Integer) @outputs(out1=Types.String) -@spark_task(spark_conf={'A': 'B'}, hadoop_conf={'C': 'D'}) +@spark_task(spark_conf={"A": "B"}, hadoop_conf={"C": "D"}) def default_task(wf_params, sc, in1, out1): pass @@ -24,27 +27,27 @@ def default_task(wf_params, sc, in1, out1): def test_default_python_task(): assert isinstance(default_task, _spark_task.SdkSparkTask) assert isinstance(default_task, _sdk_runnable.SdkRunnableTask) - assert default_task.interface.inputs['in1'].description == '' - assert default_task.interface.inputs['in1'].type == \ - _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER) - assert default_task.interface.outputs['out1'].description == '' - assert default_task.interface.outputs['out1'].type == \ - _type_models.LiteralType(simple=_type_models.SimpleType.STRING) + assert default_task.interface.inputs["in1"].description == "" + assert default_task.interface.inputs["in1"].type == _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER) + assert default_task.interface.outputs["out1"].description == "" + assert default_task.interface.outputs["out1"].type == _type_models.LiteralType( + simple=_type_models.SimpleType.STRING + ) assert default_task.type == _common_constants.SdkTaskType.SPARK_TASK - assert default_task.task_function_name == 'default_task' + assert default_task.task_function_name == "default_task" assert default_task.task_module == __name__ assert default_task.metadata.timeout == _datetime.timedelta(seconds=0) - assert default_task.metadata.deprecated_error_message == '' + assert default_task.metadata.deprecated_error_message == "" assert default_task.metadata.discoverable is False - assert default_task.metadata.discovery_version == '' + assert default_task.metadata.discovery_version == "" 
     assert default_task.metadata.retries.retries == 0
     assert len(default_task.container.resources.limits) == 0
     assert len(default_task.container.resources.requests) == 0
-    assert default_task.custom['sparkConf']['A'] == 'B'
-    assert default_task.custom['hadoopConf']['C'] == 'D'
-    assert _os.path.abspath(_entrypoint.__file__)[:-1] in default_task.custom['mainApplicationFile']
-    assert default_task.custom['executorPath'] == _sys.executable
+    assert default_task.custom["sparkConf"]["A"] == "B"
+    assert default_task.custom["hadoopConf"]["C"] == "D"
+    assert _os.path.abspath(_entrypoint.__file__)[:-1] in default_task.custom["mainApplicationFile"]
+    assert default_task.custom["executorPath"] == _sys.executable
     pb2 = default_task.to_flyte_idl()
-    assert pb2.custom['sparkConf']['A'] == 'B'
-    assert pb2.custom['hadoopConf']['C'] == 'D'
+    assert pb2.custom["sparkConf"]["A"] == "B"
+    assert pb2.custom["hadoopConf"]["C"] == "D"
diff --git a/tests/flytekit/unit/sdk/tasks/test_tasks.py b/tests/flytekit/unit/sdk/tasks/test_tasks.py
index 338cf0e8b9..30a13eaee0 100644
--- a/tests/flytekit/unit/sdk/tasks/test_tasks.py
+++ b/tests/flytekit/unit/sdk/tasks/test_tasks.py
@@ -1,13 +1,16 @@
 from __future__ import absolute_import
-from flytekit.sdk.tasks import python_task, inputs, outputs
-from flytekit.sdk.types import Types
+
+import datetime as _datetime
+import os as _os
+
+from flytekit import configuration as _configuration
 from flytekit.common import constants as _common_constants
 from flytekit.common.tasks import sdk_runnable as _sdk_runnable
-from flytekit import configuration as _configuration
-from flytekit.models import types as _type_models, task as _task_models
+from flytekit.models import task as _task_models
+from flytekit.models import types as _type_models
 from flytekit.models.core import identifier as _identifier
-import datetime as _datetime
-import os as _os
+from flytekit.sdk.tasks import inputs, outputs, python_task
+from flytekit.sdk.types import Types
 
 
 @inputs(in1=Types.Integer)
@@ -22,19 +25,19 @@ def default_task(wf_params, in1, out1):
 
 def test_default_python_task():
     assert isinstance(default_task, _sdk_runnable.SdkRunnableTask)
-    assert default_task.interface.inputs['in1'].description == ''
-    assert default_task.interface.inputs['in1'].type == \
-        _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER)
-    assert default_task.interface.outputs['out1'].description == ''
-    assert default_task.interface.outputs['out1'].type == \
-        _type_models.LiteralType(simple=_type_models.SimpleType.STRING)
+    assert default_task.interface.inputs["in1"].description == ""
+    assert default_task.interface.inputs["in1"].type == _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER)
+    assert default_task.interface.outputs["out1"].description == ""
+    assert default_task.interface.outputs["out1"].type == _type_models.LiteralType(
+        simple=_type_models.SimpleType.STRING
+    )
     assert default_task.type == _common_constants.SdkTaskType.PYTHON_TASK
-    assert default_task.task_function_name == 'default_task'
+    assert default_task.task_function_name == "default_task"
     assert default_task.task_module == __name__
     assert default_task.metadata.timeout == _datetime.timedelta(seconds=0)
-    assert default_task.metadata.deprecated_error_message == ''
+    assert default_task.metadata.deprecated_error_message == ""
     assert default_task.metadata.discoverable is False
-    assert default_task.metadata.discovery_version == ''
+    assert default_task.metadata.discovery_version == ""
     assert default_task.metadata.retries.retries == 0
     assert len(default_task.container.resources.limits) == 0
     assert len(default_task.container.resources.requests) == 0
@@ -42,26 +45,18 @@ def test_default_python_task():
 
 def test_default_resources():
     with _configuration.TemporaryConfiguration(
-        _os.path.join(
-            _os.path.dirname(_os.path.realpath(__file__)),
-            '../../configuration/configs/good.config'
-        )
+        _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../configuration/configs/good.config",)
     ):
+
         @inputs(in1=Types.Integer)
         @outputs(out1=Types.String)
         @python_task()
         def default_task2(wf_params, in1, out1):
             pass
 
-        request_map = {
-            r.name: r.value
-            for r in default_task2.container.resources.requests
-        }
+        request_map = {r.name: r.value for r in default_task2.container.resources.requests}
 
-        limit_map = {
-            l.name: l.value
-            for l in default_task2.container.resources.limits
-        }
+        limit_map = {l.name: l.value for l in default_task2.container.resources.limits}
 
         assert request_map[_task_models.Resources.ResourceName.CPU] == "500m"
         assert request_map[_task_models.Resources.ResourceName.MEMORY] == "500Gi"
@@ -76,11 +71,9 @@ def default_task2(wf_params, in1, out1):
 
 def test_overriden_resources():
     with _configuration.TemporaryConfiguration(
-        _os.path.join(
-            _os.path.dirname(_os.path.realpath(__file__)),
-            '../../configuration/configs/good.config'
-        )
+        _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../configuration/configs/good.config",)
     ):
+
         @inputs(in1=Types.Integer)
         @outputs(out1=Types.String)
         @python_task(
@@ -91,20 +84,14 @@ def test_overriden_resources():
             gpu_limit="1",
             gpu_request="0",
             storage_request="100Gi",
-            storage_limit="200Gi"
+            storage_limit="200Gi",
         )
         def default_task2(wf_params, in1, out1):
             pass
 
-        request_map = {
-            r.name: r.value
-            for r in default_task2.container.resources.requests
-        }
+        request_map = {r.name: r.value for r in default_task2.container.resources.requests}
 
-        limit_map = {
-            l.name: l.value
-            for l in default_task2.container.resources.limits
-        }
+        limit_map = {l.name: l.value for l in default_task2.container.resources.limits}
 
         assert request_map[_task_models.Resources.ResourceName.CPU] == "500m"
         assert request_map[_task_models.Resources.ResourceName.MEMORY] == "50Gi"
diff --git a/tests/flytekit/unit/sdk/test_workflow.py b/tests/flytekit/unit/sdk/test_workflow.py
index eee4efda91..fd2dcf7d2a 100644
--- a/tests/flytekit/unit/sdk/test_workflow.py
+++ b/tests/flytekit/unit/sdk/test_workflow.py
@@ -1,19 +1,18 @@
 from __future__ import absolute_import
 
 import pytest
-import datetime
 
 from flytekit.common import constants
 from flytekit.common.exceptions import user as _user_exceptions
-from flytekit.common.types import primitives, base_sdk_types, containers
-from flytekit.sdk.tasks import python_task, inputs, outputs
+from flytekit.common.types import base_sdk_types, containers, primitives
+from flytekit.sdk.tasks import inputs, outputs, python_task
 from flytekit.sdk.types import Types
-from flytekit.sdk.workflow import workflow, Input, Output, workflow_class
+from flytekit.sdk.workflow import Input, Output, workflow, workflow_class
 
 
 def test_input():
     i = Input(primitives.Integer, help="blah", default=None)
-    assert i.name == ''
+    assert i.name == ""
     assert i.sdk_default is None
     assert i.default == base_sdk_types.Void()
     assert i.sdk_required is False
@@ -22,8 +21,8 @@ def test_input():
     assert i.var.description == "blah"
     assert i.sdk_type == primitives.Integer
 
-    i = i.rename_and_return_reference('new_name')
-    assert i.name == 'new_name'
+    i = i.rename_and_return_reference("new_name")
+    assert i.name == "new_name"
     assert i.sdk_default is None
     assert i.default == base_sdk_types.Void()
     assert i.sdk_required is False
@@ -33,8 +32,8 @@ def test_input():
     assert i.sdk_type == primitives.Integer
 
     i = Input(primitives.Integer, default=1)
-    assert i.name == ''
-    assert i.sdk_default is 1
+    assert i.name == ""
+    assert i.sdk_default == 1
     assert i.default == primitives.Integer(1)
     assert i.sdk_required is False
     assert i.required is None
@@ -42,9 +41,9 @@ def test_input():
     assert i.var.description == ""
     assert i.sdk_type == primitives.Integer
 
-    i = i.rename_and_return_reference('new_name')
-    assert i.name == 'new_name'
-    assert i.sdk_default is 1
+    i = i.rename_and_return_reference("new_name")
+    assert i.name == "new_name"
+    assert i.sdk_default == 1
     assert i.default == primitives.Integer(1)
     assert i.sdk_required is False
     assert i.required is None
@@ -56,7 +55,7 @@ def test_input():
         Input(primitives.Integer, required=True, default=1)
 
     i = Input([primitives.Integer], default=[1, 2])
-    assert i.name == ''
+    assert i.name == ""
     assert i.sdk_default == [1, 2]
     assert i.default == containers.List(primitives.Integer)([primitives.Integer(1), primitives.Integer(2)])
     assert i.sdk_required is False
@@ -65,8 +64,8 @@ def test_input():
     assert i.var.description == ""
     assert i.sdk_type == containers.List(primitives.Integer)
 
-    i = i.rename_and_return_reference('new_name')
-    assert i.name == 'new_name'
+    i = i.rename_and_return_reference("new_name")
+    assert i.name == "new_name"
     assert i.sdk_default == [1, 2]
     assert i.default == containers.List(primitives.Integer)([primitives.Integer(1), primitives.Integer(2)])
     assert i.sdk_required is False
@@ -78,13 +77,13 @@ def test_input():
 
 def test_output():
     o = Output(1, sdk_type=primitives.Integer, help="blah")
-    assert o.name == ''
+    assert o.name == ""
     assert o.var.description == "blah"
     assert o.var.type == primitives.Integer.to_flyte_literal_type()
     assert o.binding_data.scalar.primitive.integer == 1
 
-    o = o.rename_and_return_reference('new_name')
-    assert o.name == 'new_name'
+    o = o.rename_and_return_reference("new_name")
+    assert o.name == "new_name"
     assert o.var.description == "blah"
     assert o.var.type == primitives.Integer.to_flyte_literal_type()
     assert o.binding_data.scalar.primitive.integer == 1
@@ -98,7 +97,6 @@ def _get_node_by_id(wf, nid):
 
 
 def test_workflow_no_node_dependencies_or_outputs():
-
     @inputs(a=Types.Integer)
     @outputs(b=Types.Integer)
     @python_task
@@ -106,33 +104,29 @@ def my_task(wf_params, a, b):
         b.set(a + 1)
 
     i1 = Input(Types.Integer)
-    i2 = Input(Types.Integer, default=5, help='Not required.')
+    i2 = Input(Types.Integer, default=5, help="Not required.")
 
-    input_dict = {
-        'input_1': i1,
-        'input_2': i2
-    }
+    input_dict = {"input_1": i1, "input_2": i2}
 
     nodes = {
-        'a': my_task(a=input_dict['input_1']),
-        'b': my_task(a=input_dict['input_2']),
-        'c': my_task(a=100)
+        "a": my_task(a=input_dict["input_1"]),
+        "b": my_task(a=input_dict["input_2"]),
+        "c": my_task(a=100),
     }
 
     w = workflow(inputs=input_dict, outputs={}, nodes=nodes)
 
-    assert w.interface.inputs['input_1'].type == Types.Integer.to_flyte_literal_type()
-    assert w.interface.inputs['input_2'].type == Types.Integer.to_flyte_literal_type()
-    assert _get_node_by_id(w, 'a').inputs[0].var == 'a'
-    assert _get_node_by_id(w, 'a').inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert _get_node_by_id(w, 'a').inputs[0].binding.promise.var == 'input_1'
-    assert _get_node_by_id(w, 'b').inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert _get_node_by_id(w, 'b').inputs[0].binding.promise.var == 'input_2'
-    assert _get_node_by_id(w, 'c').inputs[0].binding.scalar.primitive.integer == 100
+    assert w.interface.inputs["input_1"].type == Types.Integer.to_flyte_literal_type()
+    assert w.interface.inputs["input_2"].type == Types.Integer.to_flyte_literal_type()
+    assert _get_node_by_id(w, "a").inputs[0].var == "a"
+    assert _get_node_by_id(w, "a").inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
+    assert _get_node_by_id(w, "a").inputs[0].binding.promise.var == "input_1"
+    assert _get_node_by_id(w, "b").inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
+    assert _get_node_by_id(w, "b").inputs[0].binding.promise.var == "input_2"
+    assert _get_node_by_id(w, "c").inputs[0].binding.scalar.primitive.integer == 100
 
 
 def test_workflow_metaclass_no_node_dependencies_or_outputs():
-
     @inputs(a=Types.Integer)
     @outputs(b=Types.Integer)
     @python_task
@@ -142,17 +136,17 @@ def my_task(wf_params, a, b):
     @workflow_class
     class sup(object):
         input_1 = Input(Types.Integer)
-        input_2 = Input(Types.Integer, default=5, help='Not required.')
+        input_2 = Input(Types.Integer, default=5, help="Not required.")
 
         a = my_task(a=input_1)
         b = my_task(a=input_2)
         c = my_task(a=100)
 
-    assert sup.interface.inputs['input_1'].type == Types.Integer.to_flyte_literal_type()
-    assert sup.interface.inputs['input_2'].type == Types.Integer.to_flyte_literal_type()
-    assert _get_node_by_id(sup, 'a').inputs[0].var == 'a'
-    assert _get_node_by_id(sup, 'a').inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert _get_node_by_id(sup, 'a').inputs[0].binding.promise.var == 'input_1'
-    assert _get_node_by_id(sup, 'b').inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
-    assert _get_node_by_id(sup, 'b').inputs[0].binding.promise.var == 'input_2'
-    assert _get_node_by_id(sup, 'c').inputs[0].binding.scalar.primitive.integer == 100
+    assert sup.interface.inputs["input_1"].type == Types.Integer.to_flyte_literal_type()
+    assert sup.interface.inputs["input_2"].type == Types.Integer.to_flyte_literal_type()
+    assert _get_node_by_id(sup, "a").inputs[0].var == "a"
+    assert _get_node_by_id(sup, "a").inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
+    assert _get_node_by_id(sup, "a").inputs[0].binding.promise.var == "input_1"
+    assert _get_node_by_id(sup, "b").inputs[0].binding.promise.node_id == constants.GLOBAL_INPUT_NODE_ID
+    assert _get_node_by_id(sup, "b").inputs[0].binding.promise.var == "input_2"
+    assert _get_node_by_id(sup, "c").inputs[0].binding.scalar.primitive.integer == 100
diff --git a/tests/flytekit/unit/sdk/types/test_blobs.py b/tests/flytekit/unit/sdk/types/test_blobs.py
index 59d340f850..cc82ee27ce 100644
--- a/tests/flytekit/unit/sdk/types/test_blobs.py
+++ b/tests/flytekit/unit/sdk/types/test_blobs.py
@@ -1,9 +1,10 @@
 from __future__ import absolute_import
 
-from flytekit.sdk import types as _sdk_types
-from flytekit.common.types.impl import blobs as _blob_impl
 import pytest
 
+from flytekit.common.types.impl import blobs as _blob_impl
+from flytekit.sdk import types as _sdk_types
+
 
 @pytest.mark.parametrize(
     "blob_tuple",
@@ -12,7 +13,7 @@
         (_sdk_types.Types.CSV, _blob_impl.Blob),
         (_sdk_types.Types.MultiPartBlob, _blob_impl.MultiPartBlob),
         (_sdk_types.Types.MultiPartCSV, _blob_impl.MultiPartBlob),
-    ]
+    ],
 )
 def test_instantiable_blobs(blob_tuple):
     sdk_type, impl = blob_tuple
@@ -28,5 +29,5 @@ def test_instantiable_blobs(blob_tuple):
     with pytest.raises(Exception):
         sdk_type(a=1)
 
-    blob_inst = sdk_type.create_at_known_location('abc')
+    blob_inst = sdk_type.create_at_known_location("abc")
     assert isinstance(blob_inst, impl)
diff --git a/tests/flytekit/unit/sdk/types/test_primitives.py b/tests/flytekit/unit/sdk/types/test_primitives.py
index 3947b43bb7..8691963307 100644
--- a/tests/flytekit/unit/sdk/types/test_primitives.py
+++ b/tests/flytekit/unit/sdk/types/test_primitives.py
@@ -1,7 +1,9 @@
 from __future__ import absolute_import
-from flytekit.sdk import types as _sdk_types
+
 import pytest
 
+from flytekit.sdk import types as _sdk_types
+
 
 def test_integer():
     with pytest.raises(Exception):
diff --git a/tests/flytekit/unit/sdk/types/test_schema.py b/tests/flytekit/unit/sdk/types/test_schema.py
index 066a98686c..a8510b8ffa 100644
--- a/tests/flytekit/unit/sdk/types/test_schema.py
+++ b/tests/flytekit/unit/sdk/types/test_schema.py
@@ -1,10 +1,11 @@
 from __future__ import absolute_import
 
-from flytekit.sdk.types import Types
-from flytekit.sdk.tasks import inputs, outputs, python_task
-from flytekit.common.exceptions import user as _user_exceptions
 import pytest
 
+from flytekit.common.exceptions import user as _user_exceptions
+from flytekit.sdk.tasks import inputs, outputs, python_task
+from flytekit.sdk.types import Types
+
 
 def test_generic_schema():
     @inputs(a=Types.Schema())
@@ -15,8 +16,8 @@ def fake_task(wf_params, a, b):
 
 
 def test_typed_schema():
-    @inputs(a=Types.Schema([('a', Types.Integer), ('b', Types.Integer)]))
-    @outputs(b=Types.Schema([('a', Types.Integer), ('b', Types.Integer)]))
+    @inputs(a=Types.Schema([("a", Types.Integer), ("b", Types.Integer)]))
+    @outputs(b=Types.Schema([("a", Types.Integer), ("b", Types.Integer)]))
     @python_task
     def fake_task(wf_params, a, b):
         pass
@@ -29,21 +30,21 @@ def test_bad_definition():
 
 def test_bad_column_types():
     with pytest.raises(_user_exceptions.FlyteTypeException):
-        Types.Schema([('a', Types.Blob)])
+        Types.Schema([("a", Types.Blob)])
     with pytest.raises(_user_exceptions.FlyteTypeException):
-        Types.Schema([('a', Types.MultiPartBlob)])
+        Types.Schema([("a", Types.MultiPartBlob)])
     with pytest.raises(_user_exceptions.FlyteTypeException):
-        Types.Schema([('a', Types.MultiPartCSV)])
+        Types.Schema([("a", Types.MultiPartCSV)])
     with pytest.raises(_user_exceptions.FlyteTypeException):
-        Types.Schema([('a', Types.CSV)])
+        Types.Schema([("a", Types.CSV)])
     with pytest.raises(_user_exceptions.FlyteTypeException):
-        Types.Schema([('a', Types.Schema())])
+        Types.Schema([("a", Types.Schema())])
 
 
 def test_create_from_hive_query():
     s, q = Types.Schema().create_from_hive_query("SELECT * FROM table", known_location="s3://somewhere/")
-    assert s.mode == 'wb'
+    assert s.mode == "wb"
     assert s.local_path is None
     assert s.remote_location == "s3://somewhere/"
     assert "SELECT * FROM table" in q
diff --git a/tests/flytekit/unit/test_plugins.py b/tests/flytekit/unit/test_plugins.py
index cc8183fa99..611e50bcd3 100644
--- a/tests/flytekit/unit/test_plugins.py
+++ b/tests/flytekit/unit/test_plugins.py
@@ -1,13 +1,16 @@
 from __future__ import absolute_import
+
+import pytest
+
 from flytekit import plugins
 from flytekit.tools import lazy_loader
-import pytest
 
 
 @pytest.mark.run(order=0)
 def test_spark_plugin():
     plugins.pyspark.SparkContext
     import pyspark
+
     assert plugins.pyspark.SparkContext == pyspark.SparkContext
 
 
@@ -17,6 +20,7 @@ def test_schema_plugin():
     plugins.pandas.DataFrame
     import numpy
     import pandas
+
     assert plugins.numpy.dtype == numpy.dtype
     assert pandas.DataFrame == pandas.DataFrame
 
@@ -24,9 +28,10 @@ def test_schema_plugin():
 @pytest.mark.run(order=2)
 def test_sidecar_plugin():
     assert isinstance(plugins.k8s.io.api.core.v1.generated_pb2, lazy_loader._LazyLoadModule)
-    assert isinstance(plugins.k8s.io.apimachinery.pkg.api.resource.generated_pb2, lazy_loader._LazyLoadModule)
+    assert isinstance(plugins.k8s.io.apimachinery.pkg.api.resource.generated_pb2, lazy_loader._LazyLoadModule,)
     import k8s.io.api.core.v1.generated_pb2
     import k8s.io.apimachinery.pkg.api.resource.generated_pb2
+
     k8s.io.api.core.v1.generated_pb2.Container
     k8s.io.apimachinery.pkg.api.resource.generated_pb2.Quantity
 
@@ -37,5 +42,6 @@ def test_hive_sensor_plugin():
     assert isinstance(plugins.hmsclient.genthrift.hive_metastore.ttypes, lazy_loader._LazyLoadModule)
     import hmsclient
     import hmsclient.genthrift.hive_metastore.ttypes
+
     hmsclient.HMSClient
     hmsclient.genthrift.hive_metastore.ttypes.NoSuchObjectException
diff --git a/tests/flytekit/unit/tools/test_aws.py b/tests/flytekit/unit/tools/test_aws.py
index 9cdc923e8b..5107c8b1e7 100644
--- a/tests/flytekit/unit/tools/test_aws.py
+++ b/tests/flytekit/unit/tools/test_aws.py
@@ -4,6 +4,6 @@
 
 
 def test_aws_s3_splitting():
-    (bucket, key) = AwsS3Proxy._split_s3_path_to_bucket_and_key('s3://bucket/some/key')
-    assert bucket == 'bucket'
-    assert key == 'some/key'
+    (bucket, key) = AwsS3Proxy._split_s3_path_to_bucket_and_key("s3://bucket/some/key")
+    assert bucket == "bucket"
+    assert key == "some/key"
diff --git a/tests/flytekit/unit/tools/test_lazy_loader.py b/tests/flytekit/unit/tools/test_lazy_loader.py
index d0b0893af6..4b29208c93 100644
--- a/tests/flytekit/unit/tools/test_lazy_loader.py
+++ b/tests/flytekit/unit/tools/test_lazy_loader.py
@@ -1,18 +1,16 @@
 from __future__ import absolute_import
-from flytekit.tools import lazy_loader
+
 import pytest
 import six
 
+from flytekit.tools import lazy_loader
+
 
 def test_lazy_loader_error_message():
     lazy_mod = lazy_loader.lazy_load_module("made.up.module")
-    lazy_loader.LazyLoadPlugin(
-        "uninstalled_plugin",
-        [],
-        [lazy_mod]
-    )
+    lazy_loader.LazyLoadPlugin("uninstalled_plugin", [], [lazy_mod])
     with pytest.raises(ImportError) as e:
         lazy_mod.some_bad_attr
 
-    assert 'uninstalled_plugin' in six.text_type(e.value)
-    assert 'flytekit[all]' in six.text_type(e.value)
+    assert "uninstalled_plugin" in six.text_type(e.value)
+    assert "flytekit[all]" in six.text_type(e.value)
diff --git a/tests/flytekit/unit/tools/test_module_loader.py b/tests/flytekit/unit/tools/test_module_loader.py
index bc6a8af3f9..d0ef966c63 100644
--- a/tests/flytekit/unit/tools/test_module_loader.py
+++ b/tests/flytekit/unit/tools/test_module_loader.py
@@ -1,37 +1,38 @@
 from __future__ import absolute_import
+
 import os
 import sys
 
-from flytekit.tools import module_loader
 from flytekit.common import utils as _utils
+from flytekit.tools import module_loader
 
 
 def test_module_loading():
     with _utils.AutoDeletingTempDir("mypackage") as pkg:
         path = pkg.name
         # Create directories
-        top_level = os.path.join(path, 'top')
-        middle_level = os.path.join(top_level, 'middle')
-        bottom_level = os.path.join(middle_level, 'bottom')
+        top_level = os.path.join(path, "top")
+        middle_level = os.path.join(top_level, "middle")
+        bottom_level = os.path.join(middle_level, "bottom")
         os.makedirs(bottom_level)
 
         # Create init files
-        with open(os.path.join(path, '__init__.py'), 'w'):
+        with open(os.path.join(path, "__init__.py"), "w"):
             pass
-        with open(os.path.join(top_level, '__init__.py'), 'w'):
+        with open(os.path.join(top_level, "__init__.py"), "w"):
             pass
-        with open(os.path.join(top_level, 'a.py'), 'w'):
pass - with open(os.path.join(middle_level, '__init__.py'), 'w'): + with open(os.path.join(middle_level, "__init__.py"), "w"): pass - with open(os.path.join(middle_level, 'a.py'), 'w'): + with open(os.path.join(middle_level, "a.py"), "w"): pass - with open(os.path.join(bottom_level, '__init__.py'), 'w'): + with open(os.path.join(bottom_level, "__init__.py"), "w"): pass - with open(os.path.join(bottom_level, 'a.py'), 'w'): + with open(os.path.join(bottom_level, "a.py"), "w"): pass sys.path.append(path) # Not a sufficient test but passes for now - assert sum(1 for _ in module_loader.iterate_modules(['top'])) == 6 + assert sum(1 for _ in module_loader.iterate_modules(["top"])) == 6 diff --git a/tests/flytekit/unit/tools/test_subprocess.py b/tests/flytekit/unit/tools/test_subprocess.py index 7a7154c824..0b440e6ee9 100644 --- a/tests/flytekit/unit/tools/test_subprocess.py +++ b/tests/flytekit/unit/tools/test_subprocess.py @@ -13,23 +13,23 @@ def wait(self): @mock.patch.object(subprocess._subprocess, "Popen") def test_check_call(mock_call): mock_call.return_value = _MockProcess() - op = subprocess.check_call(["ls", "-l"], shell=True, env={'a': 'b'}, cwd="/tmp") + op = subprocess.check_call(["ls", "-l"], shell=True, env={"a": "b"}, cwd="/tmp") assert op == 0 mock_call.assert_called() assert mock_call.call_args[0][0] == ["ls", "-l"] - assert mock_call.call_args[1]['shell'] is True - assert mock_call.call_args[1]['env'] == {'a': 'b'} - assert mock_call.call_args[1]['cwd'] == "/tmp" + assert mock_call.call_args[1]["shell"] is True + assert mock_call.call_args[1]["env"] == {"a": "b"} + assert mock_call.call_args[1]["cwd"] == "/tmp" @mock.patch.object(subprocess._subprocess, "Popen") def test_check_call_shellex(mock_call): mock_call.return_value = _MockProcess() - op = subprocess.check_call("ls -l", shell=True, env={'a': 'b'}, cwd="/tmp") + op = subprocess.check_call("ls -l", shell=True, env={"a": "b"}, cwd="/tmp") assert op == 0 assert op == 0 mock_call.assert_called() assert mock_call.call_args[0][0] == ["ls", "-l"] - assert mock_call.call_args[1]['shell'] is True - assert mock_call.call_args[1]['env'] == {'a': 'b'} - assert mock_call.call_args[1]['cwd'] == "/tmp" + assert mock_call.call_args[1]["shell"] is True + assert mock_call.call_args[1]["env"] == {"a": "b"} + assert mock_call.call_args[1]["cwd"] == "/tmp" diff --git a/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py b/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py index 02fa4fb466..96d90ccfaa 100644 --- a/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py +++ b/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py @@ -1,19 +1,20 @@ from __future__ import absolute_import + +import pytest +from flyteidl.core import errors_pb2 as _errors_pb2 + from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.types import proto as _proto +from flytekit.models import literals as _literal_models +from flytekit.models import types as _type_models from flytekit.type_engines.default import flyte as _flyte_engine -from flytekit.models import types as _type_models, literals as _literal_models -from flyteidl.core import errors_pb2 as _errors_pb2 -import pytest def test_proto_from_literal_type(): sdk_type = _flyte_engine.FlyteDefaultTypeEngine().get_sdk_type_from_literal_type( _type_models.LiteralType( simple=_type_models.SimpleType.BINARY, - metadata={ - _proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2.ContainerError" - } + metadata={_proto.Protobuf.PB_FIELD_KEY: 
"flyteidl.core.errors_pb2.ContainerError"}, ) ) @@ -25,9 +26,7 @@ def test_unloadable_module_from_literal_type(): _flyte_engine.FlyteDefaultTypeEngine().get_sdk_type_from_literal_type( _type_models.LiteralType( simple=_type_models.SimpleType.BINARY, - metadata={ - _proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2_no_exist.ContainerError" - } + metadata={_proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2_no_exist.ContainerError"}, ) ) @@ -37,9 +36,7 @@ def test_unloadable_proto_from_literal_type(): _flyte_engine.FlyteDefaultTypeEngine().get_sdk_type_from_literal_type( _type_models.LiteralType( simple=_type_models.SimpleType.BINARY, - metadata={ - _proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2.ContainerErrorNoExist" - } + metadata={_proto.Protobuf.PB_FIELD_KEY: "flyteidl.core.errors_pb2.ContainerErrorNoExist"}, ) ) @@ -49,8 +46,7 @@ def test_infer_proto_from_literal(): _literal_models.Literal( scalar=_literal_models.Scalar( binary=_literal_models.Binary( - value="", - tag="{}{}".format(_proto.Protobuf.TAG_PREFIX, "flyteidl.core.errors_pb2.ContainerError") + value="", tag="{}{}".format(_proto.Protobuf.TAG_PREFIX, "flyteidl.core.errors_pb2.ContainerError",), ) ) ) diff --git a/tests/flytekit/unit/use_scenarios/unit_testing/test_blobs.py b/tests/flytekit/unit/use_scenarios/unit_testing/test_blobs.py index 529fa070f2..4ca56a1695 100644 --- a/tests/flytekit/unit/use_scenarios/unit_testing/test_blobs.py +++ b/tests/flytekit/unit/use_scenarios/unit_testing/test_blobs.py @@ -1,9 +1,9 @@ from __future__ import absolute_import -from flytekit.sdk.types import Types +from flytekit.common.utils import AutoDeletingTempDir from flytekit.sdk.tasks import inputs, outputs, python_task from flytekit.sdk.test_utils import flyte_test -from flytekit.common.utils import AutoDeletingTempDir +from flytekit.sdk.types import Types @flyte_test @@ -13,14 +13,14 @@ def test_create_blob_from_local_path(): def test_create_from_local_path(wf_params, a): with AutoDeletingTempDir("t") as tmp: tmp_name = tmp.get_named_tempfile("abc.blob") - with open(tmp_name, 'wb') as w: + with open(tmp_name, "wb") as w: w.write("Hello world".encode("utf-8")) a.set(tmp_name) out = test_create_from_local_path.unit_test() assert len(out) == 1 - with out['a'] as r: - assert r.read().decode('utf-8') == "Hello world" + with out["a"] as r: + assert r.read().decode("utf-8") == "Hello world" @flyte_test @@ -35,8 +35,8 @@ def test_write(wf_params, a): out = test_write.unit_test() assert len(out) == 1 - with out['a'] as r: - assert r.read().decode('utf-8') == "Hello world" + with out["a"] as r: + assert r.read().decode("utf-8") == "Hello world" @flyte_test @@ -53,13 +53,13 @@ def test_pass(wf_params, a, b): out = test_pass.unit_test(a=b) assert len(out) == 1 - with out['b'] as r: - assert r.read().decode('utf-8') == "Hello world" + with out["b"] as r: + assert r.read().decode("utf-8") == "Hello world" - out = test_pass.unit_test(a=out['b']) + out = test_pass.unit_test(a=out["b"]) assert len(out) == 1 - with out['b'] as r: - assert r.read().decode('utf-8') == "Hello world" + with out["b"] as r: + assert r.read().decode("utf-8") == "Hello world" @flyte_test @@ -68,18 +68,18 @@ def test_create_multipartblob_from_local_path(): @python_task def test_create_from_local_path(wf_params, a): with AutoDeletingTempDir("t") as tmp: - with open(tmp.get_named_tempfile("0"), 'wb') as w: + with open(tmp.get_named_tempfile("0"), "wb") as w: w.write("Hello world".encode("utf-8")) - with open(tmp.get_named_tempfile("1"), 'wb') as w: + with 
                 w.write("Hello world2".encode("utf-8"))
             a.set(tmp.name)
 
     out = test_create_from_local_path.unit_test()
     assert len(out) == 1
-    with out['a'] as r:
+    with out["a"] as r:
         assert len(r) == 2
-        assert r[0].read().decode('utf-8') == "Hello world"
-        assert r[1].read().decode('utf-8') == "Hello world2"
+        assert r[0].read().decode("utf-8") == "Hello world"
+        assert r[1].read().decode("utf-8") == "Hello world2"
 
 
 @flyte_test
@@ -96,10 +96,10 @@ def test_write(wf_params, a):
 
     out = test_write.unit_test()
     assert len(out) == 1
-    with out['a'] as r:
+    with out["a"] as r:
         assert len(r) == 2
-        assert r[0].read().decode('utf-8') == "Hello world"
-        assert r[1].read().decode('utf-8') == "Hello world2"
+        assert r[0].read().decode("utf-8") == "Hello world"
+        assert r[1].read().decode("utf-8") == "Hello world2"
 
 
 @flyte_test
@@ -118,17 +118,17 @@ def test_pass(wf_params, a, b):
 
     out = test_pass.unit_test(a=b)
     assert len(out) == 1
-    with out['b'] as r:
+    with out["b"] as r:
         assert len(r) == 2
-        assert r[0].read().decode('utf-8') == "Hello world"
-        assert r[1].read().decode('utf-8') == "Hello world2"
+        assert r[0].read().decode("utf-8") == "Hello world"
+        assert r[1].read().decode("utf-8") == "Hello world2"
 
-    out = test_pass.unit_test(a=out['b'])
+    out = test_pass.unit_test(a=out["b"])
     assert len(out) == 1
-    with out['b'] as r:
+    with out["b"] as r:
         assert len(r) == 2
-        assert r[0].read().decode('utf-8') == "Hello world"
-        assert r[1].read().decode('utf-8') == "Hello world2"
+        assert r[0].read().decode("utf-8") == "Hello world"
+        assert r[1].read().decode("utf-8") == "Hello world2"
 
 
 @flyte_test
@@ -143,7 +143,7 @@ def test_write(wf_params, a):
 
     out = test_write.unit_test()
     assert len(out) == 1
-    with out['a'] as r:
+    with out["a"] as r:
         assert r.read() == "Hello,world,hi"
@@ -161,7 +161,7 @@ def test_write(wf_params, a):
 
     out = test_write.unit_test()
     assert len(out) == 1
-    with out['a'] as r:
+    with out["a"] as r:
         assert len(r) == 2
         assert r[0].read() == "Hello,world,1"
         assert r[1].read() == "Hello,world,2"
diff --git a/tests/flytekit/unit/use_scenarios/unit_testing/test_hive_tasks.py b/tests/flytekit/unit/use_scenarios/unit_testing/test_hive_tasks.py
index 7fa59b2c4f..1a4ed7cb16 100644
--- a/tests/flytekit/unit/use_scenarios/unit_testing/test_hive_tasks.py
+++ b/tests/flytekit/unit/use_scenarios/unit_testing/test_hive_tasks.py
@@ -1,7 +1,9 @@
 from __future__ import absolute_import
-from flytekit.sdk.tasks import hive_task
+
 import pytest
 
+from flytekit.sdk.tasks import hive_task
+
 
 def test_no_queries():
     @hive_task
diff --git a/tests/flytekit/unit/use_scenarios/unit_testing/test_schemas.py b/tests/flytekit/unit/use_scenarios/unit_testing/test_schemas.py
index 6336891199..93fed3eb3d 100644
--- a/tests/flytekit/unit/use_scenarios/unit_testing/test_schemas.py
+++ b/tests/flytekit/unit/use_scenarios/unit_testing/test_schemas.py
@@ -1,12 +1,13 @@
 from __future__ import absolute_import
-from flytekit.sdk.types import Types
-from flytekit.sdk.tasks import inputs, outputs, python_task
-from flytekit.sdk.test_utils import flyte_test
-from flytekit.common.exceptions import user as _user_exceptions
 
 import pandas as pd
 import pytest
 
+from flytekit.common.exceptions import user as _user_exceptions
+from flytekit.sdk.tasks import inputs, outputs, python_task
+from flytekit.sdk.test_utils import flyte_test
+from flytekit.sdk.types import Types
+
 
 @flyte_test
 def test_generic_schema():
@@ -24,77 +25,49 @@ def copy_task(wf_params, a, b):
 
     # Test generic copy and pass through
     a = Types.Schema()()
     with a as w:
-        w.write(
-            pd.DataFrame.from_dict(
-                {
-                    'a': [1, 2, 3],
-                    'b': [4.0, 5.0, 6.0]
-                }
-            )
-        )
-        w.write(
-            pd.DataFrame.from_dict(
-                {
-                    'a': [3, 2, 1],
-                    'b': [6.0, 5.0, 4.0]
-                }
-            )
-        )
+        w.write(pd.DataFrame.from_dict({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]}))
+        w.write(pd.DataFrame.from_dict({"a": [3, 2, 1], "b": [6.0, 5.0, 4.0]}))
 
     outs = copy_task.unit_test(a=a)
-    with outs['b'] as r:
+    with outs["b"] as r:
         df = r.read()
-        assert list(df['a']) == [1, 2, 3]
-        assert list(df['b']) == [4.0, 5.0, 6.0]
+        assert list(df["a"]) == [1, 2, 3]
+        assert list(df["b"]) == [4.0, 5.0, 6.0]
         df = r.read()
-        assert list(df['a']) == [3, 2, 1]
-        assert list(df['b']) == [6.0, 5.0, 4.0]
+        assert list(df["a"]) == [3, 2, 1]
+        assert list(df["b"]) == [6.0, 5.0, 4.0]
         assert r.read() is None
 
     # Test typed copy and pass through
-    a = Types.Schema([('a', Types.Integer), ('b', Types.Float)])()
+    a = Types.Schema([("a", Types.Integer), ("b", Types.Float)])()
     with a as w:
-        w.write(
-            pd.DataFrame.from_dict(
-                {
-                    'a': [1, 2, 3],
-                    'b': [4.0, 5.0, 6.0]
-                }
-            )
-        )
-        w.write(
-            pd.DataFrame.from_dict(
-                {
-                    'a': [3, 2, 1],
-                    'b': [6.0, 5.0, 4.0]
-                }
-            )
-        )
+        w.write(pd.DataFrame.from_dict({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]}))
+        w.write(pd.DataFrame.from_dict({"a": [3, 2, 1], "b": [6.0, 5.0, 4.0]}))
 
     outs = copy_task.unit_test(a=a)
-    with outs['b'] as r:
+    with outs["b"] as r:
         df = r.read()
-        assert list(df['a']) == [1, 2, 3]
-        assert list(df['b']) == [4.0, 5.0, 6.0]
+        assert list(df["a"]) == [1, 2, 3]
+        assert list(df["b"]) == [4.0, 5.0, 6.0]
         df = r.read()
-        assert list(df['a']) == [3, 2, 1]
-        assert list(df['b']) == [6.0, 5.0, 4.0]
+        assert list(df["a"]) == [3, 2, 1]
+        assert list(df["b"]) == [6.0, 5.0, 4.0]
         assert r.read() is None
 
 
 @flyte_test
 def test_typed_schema():
-    @inputs(a=Types.Schema([('a', Types.Integer), ('b', Types.Float)]))
-    @outputs(b=Types.Schema([('a', Types.Integer), ('b', Types.Float)]))
+    @inputs(a=Types.Schema([("a", Types.Integer), ("b", Types.Float)]))
+    @outputs(b=Types.Schema([("a", Types.Integer), ("b", Types.Float)]))
     @python_task
     def copy_task(wf_params, a, b):
-        out = Types.Schema([('a', Types.Integer), ('b', Types.Float)])()
+        out = Types.Schema([("a", Types.Integer), ("b", Types.Float)])()
         with a as r:
             with out as w:
                 for df in r.iter_chunks():
@@ -102,57 +75,29 @@ def copy_task(wf_params, a, b):
         b.set(out)
 
     # Test typed copy and pass through
-    a = Types.Schema([('a', Types.Integer), ('b', Types.Float)])()
+    a = Types.Schema([("a", Types.Integer), ("b", Types.Float)])()
     with a as w:
-        w.write(
-            pd.DataFrame.from_dict(
-                {
-                    'a': [1, 2, 3],
-                    'b': [4.0, 5.0, 6.0]
-                }
-            )
-        )
-        w.write(
-            pd.DataFrame.from_dict(
-                {
-                    'a': [3, 2, 1],
-                    'b': [6.0, 5.0, 4.0]
-                }
-            )
-        )
+        w.write(pd.DataFrame.from_dict({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]}))
+        w.write(pd.DataFrame.from_dict({"a": [3, 2, 1], "b": [6.0, 5.0, 4.0]}))
 
     outs = copy_task.unit_test(a=a)
-    with outs['b'] as r:
+    with outs["b"] as r:
         df = r.read()
-        assert list(df['a']) == [1, 2, 3]
-        assert list(df['b']) == [4.0, 5.0, 6.0]
+        assert list(df["a"]) == [1, 2, 3]
+        assert list(df["b"]) == [4.0, 5.0, 6.0]
         df = r.read()
-        assert list(df['a']) == [3, 2, 1]
-        assert list(df['b']) == [6.0, 5.0, 4.0]
+        assert list(df["a"]) == [3, 2, 1]
+        assert list(df["b"]) == [6.0, 5.0, 4.0]
         assert r.read() is None
 
     # Test untyped failure
     a = Types.Schema()()
     with a as w:
-        w.write(
-            pd.DataFrame.from_dict(
-                {
-                    'a': [1, 2, 3],
-                    'b': [4.0, 5.0, 6.0]
-                }
-            )
-        )
-        w.write(
-            pd.DataFrame.from_dict(
-                {
-                    'a': [3, 2, 1],
-                    'b': [6.0, 5.0, 4.0]
-                }
-            )
-        )
+        w.write(pd.DataFrame.from_dict({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]}))
+        w.write(pd.DataFrame.from_dict({"a": [3, 2, 1], "b": [6.0, 5.0, 4.0]}))
 
     with pytest.raises(_user_exceptions.FlyteTypeException):
         copy_task.unit_test(a=a)
@@ -160,34 +105,27 @@ def copy_task(wf_params, a, b):
 
 @flyte_test
 def test_subset_of_columns():
-    @outputs(a=Types.Schema([('a', Types.Integer), ('b', Types.String)]))
+    @outputs(a=Types.Schema([("a", Types.Integer), ("b", Types.String)]))
     @python_task()
     def source(wf_params, a):
-        out = Types.Schema([('a', Types.Integer), ('b', Types.String)])()
+        out = Types.Schema([("a", Types.Integer), ("b", Types.String)])()
         with out as writer:
-            writer.write(
-                pd.DataFrame.from_dict(
-                    {
-                        'a': [1, 2, 3, 4, 5],
-                        'b': ['a', 'b', 'c', 'd', 'e']
-                    }
-                )
-            )
+            writer.write(pd.DataFrame.from_dict({"a": [1, 2, 3, 4, 5], "b": ["a", "b", "c", "d", "e"]}))
         a.set(out)
 
-    @inputs(a=Types.Schema([('a', Types.Integer)]))
+    @inputs(a=Types.Schema([("a", Types.Integer)]))
     @python_task()
     def sink(wf_params, a):
         with a as reader:
             df = reader.read(concat=True)
             assert len(df.columns.values) == 1
-            assert df['a'].tolist() == [1, 2, 3, 4, 5]
+            assert df["a"].tolist() == [1, 2, 3, 4, 5]
 
         with a as reader:
            df = reader.read(truncate_extra_columns=False)
-            assert df.columns.values.tolist() == ['a', 'b']
-            assert df['a'].tolist() == [1, 2, 3, 4, 5]
-            assert df['b'].tolist() == ['a', 'b', 'c', 'd', 'e']
+            assert df.columns.values.tolist() == ["a", "b"]
+            assert df["a"].tolist() == [1, 2, 3, 4, 5]
+            assert df["b"].tolist() == ["a", "b", "c", "d", "e"]
 
     o = source.unit_test()
     sink.unit_test(**o)
@@ -200,4 +138,4 @@ def test_no_output_set():
     def null_set(wf_params, a):
         pass
 
-    assert null_set.unit_test()['a'] is None
+    assert null_set.unit_test()["a"] is None