diff --git a/dev-requirements.in b/dev-requirements.in index d8950c5e52..6fa5ebb080 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -1,7 +1,7 @@ -c requirements.txt -black==19.10b0 -coverage +black +coverage[toml] flake8 flake8-black flake8-isort diff --git a/dev-requirements.txt b/dev-requirements.txt index 5df2507807..5acc7855b8 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -11,9 +11,8 @@ appdirs==1.4.4 attrs==20.3.0 # via # -c requirements.txt - # black # pytest -black==19.10b0 +black==20.8b1 # via # -c requirements.txt # -r dev-requirements.in @@ -22,7 +21,7 @@ click==7.1.2 # via # -c requirements.txt # black -coverage==5.5 +coverage[toml]==5.5 # via -r dev-requirements.in flake8-black==0.2.1 # via -r dev-requirements.in @@ -46,6 +45,7 @@ mock==4.0.3 mypy-extensions==0.4.3 # via # -c requirements.txt + # black # mypy mypy==0.812 # via -r dev-requirements.in @@ -83,6 +83,7 @@ toml==0.10.2 # via # -c requirements.txt # black + # coverage # pytest typed-ast==1.4.2 # via @@ -92,4 +93,5 @@ typed-ast==1.4.2 typing-extensions==3.7.4.3 # via # -c requirements.txt + # black # mypy diff --git a/doc-requirements.txt b/doc-requirements.txt index 05d748a71c..963d54ce5e 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -12,17 +12,12 @@ ansiwrap==0.8.4 # via papermill appdirs==1.4.4 # via black -appnope==0.1.2 - # via - # ipykernel - # ipython astroid==2.5.1 # via sphinx-autoapi async-generator==1.10 # via nbclient attrs==20.3.0 # via - # black # jsonschema # scantree babel==2.9.0 @@ -35,15 +30,13 @@ beautifulsoup4==4.9.3 # via # sphinx-code-include # sphinx-material -black==19.10b0 - # via - # flytekit - # papermill +black==20.8b1 + # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.36 +boto3==1.17.39 # via sagemaker-training -botocore==1.20.36 +botocore==1.20.39 # via # boto3 # s3transfer @@ -64,10 +57,11 @@ click==7.1.2 # papermill croniter==1.0.10 # via flytekit -cryptography==3.4.6 +cryptography==3.4.7 # via # -r doc-requirements.in # paramiko + # secretstorage css-html-js-minify==2.5.5 # via sphinx-material dataclasses-json==0.5.2 @@ -120,6 +114,10 @@ ipython==7.21.0 # via ipykernel jedi==0.18.0 # via ipython +jeepney==0.6.0 + # via + # keyring + # secretstorage jinja2==2.11.3 # via # nbconvert @@ -161,7 +159,9 @@ marshmallow==3.10.0 mistune==0.8.4 # via nbconvert mypy-extensions==0.4.3 - # via typing-inspect + # via + # black + # typing-inspect natsort==7.1.1 # via flytekit nbclient==0.5.3 @@ -287,6 +287,8 @@ scantree==0.0.1 # via dirhash scipy==1.6.2 # via sagemaker-training +secretstorage==3.3.1 + # via keyring six==1.15.0 # via # bcrypt @@ -377,7 +379,9 @@ traitlets==5.0.5 typed-ast==1.4.2 # via black typing-extensions==3.7.4.3 - # via typing-inspect + # via + # black + # typing-inspect typing-inspect==0.6.0 # via dataclasses-json unidecode==1.2.0 diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index c3b9ed76d3..86c94fa2d7 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -126,7 +126,11 @@ def _dispatch_execute(ctx: FlyteContext, task_def: PythonTask, inputs_path: str, _logging.error("!! 
Begin Unknown System Error Captured by Flyte !!") exc_str = _traceback.format_exc() output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument( - _error_models.ContainerError("SYSTEM:Unknown", exc_str, _error_models.ContainerError.Kind.RECOVERABLE,) + _error_models.ContainerError( + "SYSTEM:Unknown", + exc_str, + _error_models.ContainerError.Kind.RECOVERABLE, + ) ) _logging.error(exc_str) _logging.error("!! End Error Captured by Flyte !!") @@ -185,11 +189,13 @@ def _handle_annotated_task(task_def: PythonTask, inputs: str, output_prefix: str if cloud_provider == _constants.CloudProvider.AWS: file_access = _data_proxy.FileAccessProvider( - local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), remote_proxy=_s3proxy.AwsS3Proxy(raw_output_data_prefix), + local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), + remote_proxy=_s3proxy.AwsS3Proxy(raw_output_data_prefix), ) elif cloud_provider == _constants.CloudProvider.GCP: file_access = _data_proxy.FileAccessProvider( - local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), remote_proxy=_gcs_proxy.GCSProxy(raw_output_data_prefix), + local_sandbox_dir=_sdk_config.LOCAL_SANDBOX.get(), + remote_proxy=_gcs_proxy.GCSProxy(raw_output_data_prefix), ) elif cloud_provider == _constants.CloudProvider.LOCAL: # A fake remote using the local disk will automatically be created @@ -353,7 +359,9 @@ def _pass_through(): @_click.option("--test", is_flag=True) @_click.option("--resolver", required=False) @_click.argument( - "resolver-args", type=_click.UNPROCESSED, nargs=-1, + "resolver-args", + type=_click.UNPROCESSED, + nargs=-1, ) def execute_task_cmd( task_module, task_name, inputs, output_prefix, raw_output_data_prefix, test, resolver, resolver_args @@ -408,15 +416,29 @@ def fast_execute_task_cmd(additional_distribution, dest_dir, task_execute_cmd): @_click.option("--test", is_flag=True) @_click.option("--resolver", required=True) @_click.argument( - "resolver-args", type=_click.UNPROCESSED, nargs=-1, + "resolver-args", + type=_click.UNPROCESSED, + nargs=-1, ) def map_execute_task_cmd( - inputs, output_prefix, raw_output_data_prefix, max_concurrency, test, resolver, resolver_args, + inputs, + output_prefix, + raw_output_data_prefix, + max_concurrency, + test, + resolver, + resolver_args, ): _click.echo(_utils.get_version_message()) _execute_map_task( - inputs, output_prefix, raw_output_data_prefix, max_concurrency, test, resolver, resolver_args, + inputs, + output_prefix, + raw_output_data_prefix, + max_concurrency, + test, + resolver, + resolver_args, ) diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index 884fd8c22c..a5128f7e6d 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -316,7 +316,8 @@ def create_launch_plan(self, launch_plan_identifer, launch_plan_spec): """ super(SynchronousFlyteClient, self).create_launch_plan( _launch_plan_pb2.LaunchPlanCreateRequest( - id=launch_plan_identifer.to_flyte_idl(), spec=launch_plan_spec.to_flyte_idl(), + id=launch_plan_identifer.to_flyte_idl(), + spec=launch_plan_spec.to_flyte_idl(), ) ) @@ -506,7 +507,9 @@ def update_named_entity(self, resource_type, id, metadata): """ super(SynchronousFlyteClient, self).update_named_entity( _common_pb2.NamedEntityUpdateRequest( - resource_type=resource_type, id=id.to_flyte_idl(), metadata=metadata.to_flyte_idl(), + resource_type=resource_type, + id=id.to_flyte_idl(), + metadata=metadata.to_flyte_idl(), ) ) @@ -661,7 +664,12 @@ def get_node_execution_data(self, node_execution_identifier): ) def list_node_executions( - self, 
workflow_execution_identifier, limit=100, token=None, filters=None, sort_by=None, + self, + workflow_execution_identifier, + limit=100, + token=None, + filters=None, + sort_by=None, ): """ TODO: Comment @@ -689,7 +697,12 @@ def list_node_executions( ) def list_node_executions_for_task_paginated( - self, task_execution_identifier, limit=100, token=None, filters=None, sort_by=None, + self, + task_execution_identifier, + limit=100, + token=None, + filters=None, + sort_by=None, ): """ This returns nodes spawned by a specific task execution. This is generally from things like dynamic tasks. @@ -747,7 +760,12 @@ def get_task_execution_data(self, task_execution_identifier): ) def list_task_executions_paginated( - self, node_execution_identifier, limit=100, token=None, filters=None, sort_by=None, + self, + node_execution_identifier, + limit=100, + token=None, + filters=None, + sort_by=None, ): """ :param flytekit.models.core.identifier.NodeExecutionIdentifier node_execution_identifier: @@ -786,7 +804,9 @@ def register_project(self, project): :rtype: flyteidl.admin.project_pb2.ProjectRegisterResponse """ super(SynchronousFlyteClient, self).register_project( - _project_pb2.ProjectRegisterRequest(project=project.to_flyte_idl(),) + _project_pb2.ProjectRegisterRequest( + project=project.to_flyte_idl(), + ) ) def update_project(self, project): @@ -853,7 +873,9 @@ def update_project_domain_attributes(self, project, domain, matching_attributes) super(SynchronousFlyteClient, self).update_project_domain_attributes( _project_domain_attributes_pb2.ProjectDomainAttributesUpdateRequest( attributes=_project_domain_attributes_pb2.ProjectDomainAttributes( - project=project, domain=domain, matching_attributes=matching_attributes.to_flyte_idl(), + project=project, + domain=domain, + matching_attributes=matching_attributes.to_flyte_idl(), ) ) ) @@ -888,7 +910,9 @@ def get_project_domain_attributes(self, project, domain, resource_type): """ return super(SynchronousFlyteClient, self).get_project_domain_attributes( _project_domain_attributes_pb2.ProjectDomainAttributesGetRequest( - project=project, domain=domain, resource_type=resource_type, + project=project, + domain=domain, + resource_type=resource_type, ) ) @@ -903,7 +927,10 @@ def get_workflow_attributes(self, project, domain, workflow, resource_type): """ return super(SynchronousFlyteClient, self).get_workflow_attributes( _workflow_attributes_pb2.WorkflowAttributesGetRequest( - project=project, domain=domain, workflow=workflow, resource_type=resource_type, + project=project, + domain=domain, + workflow=workflow, + resource_type=resource_type, ) ) @@ -914,5 +941,7 @@ def list_matchable_attributes(self, resource_type): :return: """ return super(SynchronousFlyteClient, self).list_matchable_attributes( - _matchable_resource_pb2.ListMatchableAttributesRequest(resource_type=resource_type,) + _matchable_resource_pb2.ListMatchableAttributesRequest( + resource_type=resource_type, + ) ) diff --git a/flytekit/clients/helpers.py b/flytekit/clients/helpers.py index 75b2232636..4d8a7912a4 100644 --- a/flytekit/clients/helpers.py +++ b/flytekit/clients/helpers.py @@ -1,5 +1,9 @@ def iterate_node_executions( - client, workflow_execution_identifier=None, task_execution_identifier=None, limit=None, filters=None, + client, + workflow_execution_identifier=None, + task_execution_identifier=None, + limit=None, + filters=None, ): """ This returns a generator for node executions. 
@@ -25,7 +29,10 @@ def iterate_node_executions( ) else: node_execs, next_token = client.list_node_executions_for_task_paginated( - task_execution_identifier=task_execution_identifier, limit=num_to_fetch, token=token, filters=filters, + task_execution_identifier=task_execution_identifier, + limit=num_to_fetch, + token=token, + filters=filters, ) for n in node_execs: counter += 1 @@ -53,7 +60,10 @@ def iterate_task_executions(client, node_execution_identifier, limit=None, filte counter = 0 while True: task_execs, next_token = client.list_task_executions_paginated( - node_execution_identifier=node_execution_identifier, limit=num_to_fetch, token=token, filters=filters, + node_execution_identifier=node_execution_identifier, + limit=num_to_fetch, + token=token, + filters=filters, ) for t in task_execs: counter += 1 diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index cb10666151..18e923ff62 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -151,7 +151,9 @@ def __init__(self, url, insecure=False, credentials=None, options=None): self._channel = _insecure_channel(url, options=list((options or {}).items())) else: self._channel = _secure_channel( - url, credentials or _ssl_channel_credentials(), options=list((options or {}).items()), + url, + credentials or _ssl_channel_credentials(), + options=list((options or {}).items()), ) self._stub = _admin_service.AdminServiceStub(self._channel) self._metadata = None @@ -165,7 +167,12 @@ def url(self) -> str: def set_access_token(self, access_token): # Always set the header to lower-case regardless of what the config is. The grpc libraries that Admin uses # to parse the metadata don't change the metadata, but they do automatically lower the key you're looking for. - self._metadata = [(_creds_config.AUTHORIZATION_METADATA_KEY.get().lower(), "Bearer {}".format(access_token),)] + self._metadata = [ + ( + _creds_config.AUTHORIZATION_METADATA_KEY.get().lower(), + "Bearer {}".format(access_token), + ) + ] def force_auth_flow(self): refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get()) diff --git a/flytekit/clis/auth/auth.py b/flytekit/clis/auth/auth.py index 2afc4644e7..43ce73ee0a 100644 --- a/flytekit/clis/auth/auth.py +++ b/flytekit/clis/auth/auth.py @@ -118,7 +118,12 @@ class OAuthHTTPServer(_BaseHTTPServer.HTTPServer): """ def __init__( - self, server_address, RequestHandlerClass, bind_and_activate=True, redirect_path=None, queue=None, + self, + server_address, + RequestHandlerClass, + bind_and_activate=True, + redirect_path=None, + queue=None, ): _BaseHTTPServer.HTTPServer.__init__(self, server_address, RequestHandlerClass, bind_and_activate) self._redirect_path = redirect_path @@ -233,7 +238,10 @@ def request_access_token(self, auth_code): {"code": auth_code.code, "code_verifier": self._code_verifier, "grant_type": "authorization_code"} ) resp = _requests.post( - url=self._token_endpoint, data=self._params, headers=self._headers, allow_redirects=False, + url=self._token_endpoint, + data=self._params, + headers=self._headers, + allow_redirects=False, ) if resp.status_code != _StatusCodes.OK: # TODO: handle expected (?) 
error cases: diff --git a/flytekit/clis/auth/discovery.py b/flytekit/clis/auth/discovery.py index d661f6302e..5134eab974 100644 --- a/flytekit/clis/auth/discovery.py +++ b/flytekit/clis/auth/discovery.py @@ -46,7 +46,9 @@ def authorization_endpoints(self): def get_authorization_endpoints(self): if self.authorization_endpoints is not None: return self.authorization_endpoints - resp = _requests.get(url=self._discovery_url,) + resp = _requests.get( + url=self._discovery_url, + ) response_body = resp.json() diff --git a/flytekit/clis/flyte_cli/main.py b/flytekit/clis/flyte_cli/main.py index 4d8cefea45..74f082b42e 100644 --- a/flytekit/clis/flyte_cli/main.py +++ b/flytekit/clis/flyte_cli/main.py @@ -114,7 +114,11 @@ def _get_io_string(literal_map, verbose=False): if value_dict: return "\n" + "\n".join( "{:30}: {}".format( - k, _prefix_lines("{:30} ".format(""), v.verbose_string() if verbose else v.short_string(),), + k, + _prefix_lines( + "{:30} ".format(""), + v.verbose_string() if verbose else v.short_string(), + ), ) for k, v in _six.iteritems(value_dict) ) @@ -203,7 +207,10 @@ def _secho_node_execution_status(status, nl=True): fg = "blue" _click.secho( - "{:10} ".format(_tt(_core_execution_models.NodeExecutionPhase.enum_to_string(status))), bold=True, fg=fg, nl=nl, + "{:10} ".format(_tt(_core_execution_models.NodeExecutionPhase.enum_to_string(status))), + bold=True, + fg=fg, + nl=nl, ) @@ -228,7 +235,10 @@ def _secho_task_execution_status(status, nl=True): fg = "blue" _click.secho( - "{:10} ".format(_tt(_core_execution_models.TaskExecutionPhase.enum_to_string(status))), bold=True, fg=fg, nl=nl, + "{:10} ".format(_tt(_core_execution_models.TaskExecutionPhase.enum_to_string(status))), + bold=True, + fg=fg, + nl=nl, ) @@ -245,7 +255,8 @@ def _secho_one_execution(ex, urns_only): _secho_workflow_status(ex.closure.phase) else: _click.echo( - "{:100}".format(_tt(_identifier.WorkflowExecutionIdentifier.promote_from_model(ex.id))), nl=True, + "{:100}".format(_tt(_identifier.WorkflowExecutionIdentifier.promote_from_model(ex.id))), + nl=True, ) @@ -305,19 +316,33 @@ def _render_schedule_expr(lp): _project_option = _click.option(*_PROJECT_FLAGS, required=True, help="The project namespace to query.") _optional_project_option = _click.option( - *_PROJECT_FLAGS, required=False, default=None, help="[Optional] The project namespace to query.", + *_PROJECT_FLAGS, + required=False, + default=None, + help="[Optional] The project namespace to query.", ) _domain_option = _click.option(*_DOMAIN_FLAGS, required=True, help="The domain namespace to query.") _optional_domain_option = _click.option( - *_DOMAIN_FLAGS, required=False, default=None, help="[Optional] The domain namespace to query.", + *_DOMAIN_FLAGS, + required=False, + default=None, + help="[Optional] The domain namespace to query.", ) _name_option = _click.option(*_NAME_FLAGS, required=True, help="The name to query.") _optional_name_option = _click.option( - *_NAME_FLAGS, required=False, type=str, default=None, help="[Optional] The name to query.", + *_NAME_FLAGS, + required=False, + type=str, + default=None, + help="[Optional] The name to query.", ) _principal_option = _click.option(*_PRINCIPAL_FLAGS, required=True, help="Your team name, or your name") _optional_principal_option = _click.option( - *_PRINCIPAL_FLAGS, required=False, type=str, default=None, help="[Optional] Your team name, or your name", + *_PRINCIPAL_FLAGS, + required=False, + type=str, + default=None, + help="[Optional] Your team name, or your name", ) _insecure_option = 
_click.option(*_INSECURE_FLAGS, is_flag=True, required=True, help="Do not use SSL") _urn_option = _click.option("-u", "--urn", required=True, help="The unique identifier for an entity.") @@ -340,10 +365,19 @@ def _render_schedule_expr(lp): help="Pagination token from which to start listing in the list of results.", ) _limit_option = _click.option( - "-l", "--limit", required=False, default=100, type=int, help="Maximum number of results to return for this call.", + "-l", + "--limit", + required=False, + default=100, + type=int, + help="Maximum number of results to return for this call.", ) _show_all_option = _click.option( - "-a", "--show-all", is_flag=True, default=False, help="Set this flag to page through and list all results.", + "-a", + "--show-all", + is_flag=True, + default=False, + help="Set this flag to page through and list all results.", ) # TODO: Provide documentation on filter format _filter_option = _click.option( @@ -367,10 +401,15 @@ def _render_schedule_expr(lp): help="The state change to apply to a named entity", ) _named_entity_description_option = _click.option( - "--description", required=False, type=str, help="Concise description for the entity.", + "--description", + required=False, + type=str, + help="Concise description for the entity.", ) _sort_by_option = _click.option( - "--sort-by", required=False, help="Provide an entity field to be sorted. i.e. asc(name) or desc(name)", + "--sort-by", + required=False, + help="Provide an entity field to be sorted. i.e. asc(name) or desc(name)", ) _show_io_option = _click.option( "--show-io", @@ -380,7 +419,10 @@ def _render_schedule_expr(lp): " inputs and outputs.", ) _verbose_option = _click.option( - "--verbose", is_flag=True, default=False, help="Set this flag to view the full textual description of all fields.", + "--verbose", + is_flag=True, + default=False, + help="Set this flag to view the full textual description of all fields.", ) _filename_option = _click.option("-f", "--filename", required=True, help="File path of pb file") @@ -391,7 +433,10 @@ def _render_schedule_expr(lp): help="Dot (.) separated path to Python IDL class. (e.g. flyteidl.core.workflow_closure_pb2.WorkflowClosure)", ) _cause_option = _click.option( - "-c", "--cause", required=True, help="The message signaling the cause of the termination of the execution(s)", + "-c", + "--cause", + required=True, + help="The message signaling the cause of the termination of the execution(s)", ) _optional_urns_only_option = _click.option( "--urns-only", @@ -401,13 +446,25 @@ def _render_schedule_expr(lp): help="[Optional] Set the flag if you want to output the urn(s) only. 
Setting this will override the verbose flag", ) _project_identifier_option = _click.option( - "-p", "--identifier", required=True, type=str, help="Unique identifier for the project.", + "-p", + "--identifier", + required=True, + type=str, + help="Unique identifier for the project.", ) _project_name_option = _click.option( - "-n", "--name", required=True, type=str, help="The human-readable name for the project.", + "-n", + "--name", + required=True, + type=str, + help="The human-readable name for the project.", ) _project_description_option = _click.option( - "-d", "--description", required=True, type=str, help="Concise description for the project.", + "-d", + "--description", + required=True, + type=str, + help="Concise description for the project.", ) _watch_option = _click.option( "-w", @@ -428,7 +485,11 @@ def _render_schedule_expr(lp): _output_location_prefix_option = _click.option( "-o", "--output-location-prefix", help="Custom output location prefix for offloaded types (files/schemas)" ) -_files_argument = _click.argument("files", type=_click.Path(exists=True), nargs=-1,) +_files_argument = _click.argument( + "files", + type=_click.Path(exists=True), + nargs=-1, +) class _FlyteSubCommand(_click.Command): @@ -612,7 +673,12 @@ def list_task_versions(project, domain, name, host, insecure, token, limit, show sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None, ) for t in task_list: - _click.echo("{:50} {:40}".format(_tt(t.id.version), _tt(_identifier.Identifier.promote_from_model(t.id)),)) + _click.echo( + "{:50} {:40}".format( + _tt(t.id.version), + _tt(_identifier.Identifier.promote_from_model(t.id)), + ) + ) if show_all is not True: if next_token: @@ -770,7 +836,12 @@ def list_workflow_versions(project, domain, name, host, insecure, token, limit, sort_by=_admin_common.Sort.from_python_std(sort_by) if sort_by else None, ) for w in wf_list: - _click.echo("{:50} {:40}".format(_tt(w.id.version), _tt(_identifier.Identifier.promote_from_model(w.id)),)) + _click.echo( + "{:50} {:40}".format( + _tt(w.id.version), + _tt(_identifier.Identifier.promote_from_model(w.id)), + ) + ) if show_all is not True: if next_token: @@ -912,7 +983,17 @@ def list_active_launch_plans(project, domain, host, insecure, token, limit, show @_sort_by_option @_optional_urns_only_option def list_launch_plan_versions( - project, domain, name, host, insecure, token, limit, show_all, filter, sort_by, urns_only, + project, + domain, + name, + host, + insecure, + token, + limit, + show_all, + filter, + sort_by, + urns_only, ): """ List the versions of all the launch plans under the scope specified by {project, domain}. 
@@ -937,7 +1018,10 @@ def list_launch_plan_versions( _click.echo(_tt(_identifier.Identifier.promote_from_model(l.id))) else: _click.echo( - "{:50} {:80} ".format(_tt(l.id.version), _tt(_identifier.Identifier.promote_from_model(l.id)),), + "{:50} {:80} ".format( + _tt(l.id.version), + _tt(_identifier.Identifier.promote_from_model(l.id)), + ), nl=False, ) if l.spec.entity_metadata.schedule is not None and ( @@ -1303,21 +1387,27 @@ def _get_io(node_executions, wf_execution, show_io, verbose): def _render_workflow_execution(wf_execution, uri_to_message_map, show_io, verbose): _click.echo( "\nExecution {project}:{domain}:{name}\n".format( - project=_tt(wf_execution.id.project), domain=_tt(wf_execution.id.domain), name=_tt(wf_execution.id.name), + project=_tt(wf_execution.id.project), + domain=_tt(wf_execution.id.domain), + name=_tt(wf_execution.id.name), ) ) _click.echo("\t{:15} ".format("State:"), nl=False) _secho_workflow_status(wf_execution.closure.phase) _click.echo( "\t{:15} {}".format( - "Launch Plan:", _tt(_identifier.Identifier.promote_from_model(wf_execution.spec.launch_plan)), + "Launch Plan:", + _tt(_identifier.Identifier.promote_from_model(wf_execution.spec.launch_plan)), ) ) if show_io: _click.secho( "\tInputs: {}\n".format( - _prefix_lines("\t\t", _get_io_string(wf_execution.closure.computed_inputs, verbose=verbose),) + _prefix_lines( + "\t\t", + _get_io_string(wf_execution.closure.computed_inputs, verbose=verbose), + ) ) ) if wf_execution.closure.outputs is not None: @@ -1326,14 +1416,20 @@ def _render_workflow_execution(wf_execution, uri_to_message_map, show_io, verbos "\tOutputs: {}\n".format( _prefix_lines( "\t\t", - uri_to_message_map.get(wf_execution.closure.outputs.uri, wf_execution.closure.outputs.uri,), + uri_to_message_map.get( + wf_execution.closure.outputs.uri, + wf_execution.closure.outputs.uri, + ), ) ) ) elif wf_execution.closure.outputs.values is not None: _click.secho( "\tOutputs: {}\n".format( - _prefix_lines("\t\t", _get_io_string(wf_execution.closure.outputs.values, verbose=verbose),) + _prefix_lines( + "\t\t", + _get_io_string(wf_execution.closure.outputs.values, verbose=verbose), + ) ) ) else: @@ -1341,7 +1437,9 @@ def _render_workflow_execution(wf_execution, uri_to_message_map, show_io, verbos if wf_execution.closure.error is not None: _click.secho( - _prefix_lines("\t", _render_error(wf_execution.closure.error)), fg="red", bold=True, + _prefix_lines("\t", _render_error(wf_execution.closure.error)), + fg="red", + bold=True, ) @@ -1360,7 +1458,9 @@ def _get_all_task_executions_for_node(client, node_execution_identifier): while True: num_to_fetch = 100 task_execs, next_token = client.list_task_executions_paginated( - node_execution_identifier=node_execution_identifier, limit=num_to_fetch, token=token, + node_execution_identifier=node_execution_identifier, + limit=num_to_fetch, + token=token, ) for te in task_execs: fetched_task_execs.append(te) @@ -1379,11 +1479,15 @@ def _get_all_node_executions(client, workflow_execution_identifier=None, task_ex num_to_fetch = 100 if workflow_execution_identifier: node_execs, next_token = client.list_node_executions( - workflow_execution_identifier=workflow_execution_identifier, limit=num_to_fetch, token=token, + workflow_execution_identifier=workflow_execution_identifier, + limit=num_to_fetch, + token=token, ) else: node_execs, next_token = client.list_node_executions_for_task_paginated( - task_execution_identifier=task_execution_identifier, limit=num_to_fetch, token=token, + 
task_execution_identifier=task_execution_identifier, + limit=num_to_fetch, + token=token, ) all_node_execs.extend(node_execs) if not next_token: @@ -1412,7 +1516,11 @@ def _render_node_executions(client, node_execs, show_io, verbose, host, insecure _click.echo("\t\t\t{:15} {:60} ".format("Duration:", _tt(ne.closure.duration))) _click.echo( "\t\t\t{:15} {}".format( - "Input:", _prefix_lines("\t\t\t{:15} ".format(""), uri_to_message_map.get(ne.input_uri, ne.input_uri),), + "Input:", + _prefix_lines( + "\t\t\t{:15} ".format(""), + uri_to_message_map.get(ne.input_uri, ne.input_uri), + ), ) ) if ne.closure.output_uri: @@ -1420,13 +1528,16 @@ def _render_node_executions(client, node_execs, show_io, verbose, host, insecure "\t\t\t{:15} {}".format( "Output:", _prefix_lines( - "\t\t\t{:15} ".format(""), uri_to_message_map.get(ne.closure.output_uri, ne.closure.output_uri), + "\t\t\t{:15} ".format(""), + uri_to_message_map.get(ne.closure.output_uri, ne.closure.output_uri), ), ) ) if ne.closure.error is not None: _click.secho( - _prefix_lines("\t\t\t", _render_error(ne.closure.error)), bold=True, fg="red", + _prefix_lines("\t\t\t", _render_error(ne.closure.error)), + bold=True, + fg="red", ) task_executions = node_executions_to_task_executions.get(ne.id, []) @@ -1450,7 +1561,9 @@ def _render_node_executions(client, node_execs, show_io, verbose, host, insecure if te.closure.error is not None: _click.secho( - _prefix_lines("\t\t\t\t\t", _render_error(te.closure.error)), bold=True, fg="red", + _prefix_lines("\t\t\t\t\t", _render_error(te.closure.error)), + bold=True, + fg="red", ) if te.is_parent: @@ -1497,7 +1610,8 @@ def get_child_executions(urn, host, insecure, show_io, verbose): _welcome_message() client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure) node_execs = _get_all_node_executions( - client, task_execution_identifier=_identifier.TaskExecutionIdentifier.from_python_std(urn), + client, + task_execution_identifier=_identifier.TaskExecutionIdentifier.from_python_std(urn), ) _render_node_executions(client, node_execs, show_io, verbose, host, insecure) @@ -2197,7 +2311,8 @@ def setup_config(host, insecure): config_dir = _os.path.join(_get_user_filepath_home(), _default_config_file_dir) if not _os.path.isdir(config_dir): _click.secho( - "Creating default Flyte configuration directory at {}".format(_tt(config_dir)), fg="blue", + "Creating default Flyte configuration directory at {}".format(_tt(config_dir)), + fg="blue", ) _os.mkdir(config_dir) diff --git a/flytekit/clis/helpers.py b/flytekit/clis/helpers.py index b764786104..0d264054de 100644 --- a/flytekit/clis/helpers.py +++ b/flytekit/clis/helpers.py @@ -149,7 +149,11 @@ def _hydrate_workflow_template_nodes( def hydrate_registration_parameters( - resource_type: int, project: str, domain: str, version: str, entity: Union[LaunchPlan, WorkflowSpec, TaskSpec], + resource_type: int, + project: str, + domain: str, + version: str, + entity: Union[LaunchPlan, WorkflowSpec, TaskSpec], ) -> Tuple[_identifier_pb2.Identifier, Union[LaunchPlan, WorkflowSpec, TaskSpec]]: """ This is called at registration time to fill out identifier fields (e.g. project, domain, version) that are mutable. diff --git a/flytekit/clis/sdk_in_container/constants.py b/flytekit/clis/sdk_in_container/constants.py index 84d8956cb8..bc2ff68666 100644 --- a/flytekit/clis/sdk_in_container/constants.py +++ b/flytekit/clis/sdk_in_container/constants.py @@ -16,8 +16,16 @@ help="Flyte project to use. 
You can have more than one project per repo", ) domain_option = _click.option( - "-d", "--domain", required=True, type=str, help="This is usually development, staging, or production", + "-d", + "--domain", + required=True, + type=str, + help="This is usually development, staging, or production", ) version_option = _click.option( - "-v", "--version", required=False, type=str, help="This is the version to apply globally for this context", + "-v", + "--version", + required=False, + type=str, + help="This is the version to apply globally for this context", ) diff --git a/flytekit/clis/sdk_in_container/launch_plan.py b/flytekit/clis/sdk_in_container/launch_plan.py index 49102070b0..e367fa08d8 100644 --- a/flytekit/clis/sdk_in_container/launch_plan.py +++ b/flytekit/clis/sdk_in_container/launch_plan.py @@ -63,7 +63,9 @@ def get_command(self, ctx, lp_argument): launch_plan = ctx.obj["lps"][lp_argument] else: for m, k, lp in iterate_registerable_entities_in_order( - pkgs, include_entities={_SdkLaunchPlan}, detect_unreferenced_entities=False, + pkgs, + include_entities={_SdkLaunchPlan}, + detect_unreferenced_entities=False, ): safe_name = _utils.fqdn(m.__name__, k, entity_type=lp.resource_type) if lp_argument == safe_name: @@ -114,7 +116,10 @@ def _execute_lp(**kwargs): notification_overrides=ctx.obj.get(_constants.CTX_NOTIFICATIONS, None), ) click.echo( - click.style("Workflow scheduled, execution_id={}".format(_six.text_type(execution.id)), fg="blue",) + click.style( + "Workflow scheduled, execution_id={}".format(_six.text_type(execution.id)), + fg="blue", + ) ) command = click.Command(name=cmd_name, callback=_execute_lp) @@ -130,7 +135,12 @@ def _execute_lp(**kwargs): if param.required: # If it's a required input, add the required flag - wrapper = click.option("--{}".format(var_name), required=True, type=_six.text_type, help=help_msg,) + wrapper = click.option( + "--{}".format(var_name), + required=True, + type=_six.text_type, + help=help_msg, + ) else: # If it's not a required input, it should have a default # Use to_python_std so that the text of the default ends up being parseable, if not, the click @@ -217,7 +227,8 @@ def activate_all_schedules(ctx, version=None): The behavior of this command is identical to activate-all. """ click.secho( - "activate-all-schedules is deprecated, please use activate-all instead.", color="yellow", + "activate-all-schedules is deprecated, please use activate-all instead.", + color="yellow", ) project = ctx.obj[_constants.CTX_PROJECT] domain = ctx.obj[_constants.CTX_DOMAIN] @@ -234,7 +245,9 @@ def activate_all_schedules(ctx, version=None): help="Version to register tasks with. 
This is normally parsed from the" "image, but you can override here.", ) @click.option( - "--ignore-schedules", is_flag=True, help="Activate all except for launch plans with schedules.", + "--ignore-schedules", + is_flag=True, + help="Activate all except for launch plans with schedules.", ) @click.pass_context def activate_all(ctx, version=None, ignore_schedules=False): diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index a888367ca6..cbb15fe02a 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -20,7 +20,11 @@ @click.group("pyflyte", invoke_without_command=True) @click.option( - "-c", "--config", required=False, type=str, help="Path to config file for use within container", + "-c", + "--config", + required=False, + type=str, + help="Path to config file for use within container", ) @click.option( "-k", @@ -31,7 +35,11 @@ "option will override the option specified in the configuration file, or environment variable", ) @click.option( - "-i", "--insecure", required=False, type=bool, help="Do not use SSL to connect to Admin", + "-i", + "--insecure", + required=False, + type=bool, + help="Do not use SSL to connect to Admin", ) @click.pass_context def main(ctx, config=None, pkgs=None, insecure=None): @@ -71,7 +79,8 @@ def update_configuration_file(config_file_path): configuration_file = Path(config_file_path or CONFIGURATION_PATH.get()) if configuration_file.is_file(): click.secho( - "Using configuration file at {}".format(configuration_file.absolute().as_posix()), fg="green", + "Using configuration file at {}".format(configuration_file.absolute().as_posix()), + fg="green", ) set_flyte_config_file(configuration_file.as_posix()) else: diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index 25fe1ee9c7..876629b851 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -71,7 +71,9 @@ def register_tasks_only(project, domain, pkgs, test, version): @version_option # --pkgs on the register group is DEPRECATED, use same arg on pyflyte.main instead @click.option( - "--pkgs", multiple=True, help="DEPRECATED. This arg can only be used before the 'register' keyword", + "--pkgs", + multiple=True, + help="DEPRECATED. This arg can only be used before the 'register' keyword", ) @click.option("--test", is_flag=True, help="Dry run, do not actually register with Admin") @click.pass_context diff --git a/flytekit/common/core/identifier.py b/flytekit/common/core/identifier.py index 1510bdf293..c7b12a5190 100644 --- a/flytekit/common/core/identifier.py +++ b/flytekit/common/core/identifier.py @@ -21,7 +21,11 @@ def promote_from_model(cls, base_model): :rtype: Identifier """ return cls( - base_model.resource_type, base_model.project, base_model.domain, base_model.name, base_model.version, + base_model.resource_type, + base_model.project, + base_model.domain, + base_model.name, + base_model.version, ) @classmethod @@ -66,7 +70,11 @@ def promote_from_model(cls, base_model): :param flytekit.models.core.identifier.WorkflowExecutionIdentifier base_model: :rtype: WorkflowExecutionIdentifier """ - return cls(base_model.project, base_model.domain, base_model.name,) + return cls( + base_model.project, + base_model.domain, + base_model.name, + ) @classmethod def from_python_std(cls, string): @@ -91,7 +99,11 @@ def from_python_std(cls, string): "The provided string could not be parsed. 
The first element of an execution identifier must be 'ex'.", ) - return cls(project, domain, name,) + return cls( + project, + domain, + name, + ) def __str__(self): return "ex:{}:{}:{}".format(self.project, self.domain, self.name) @@ -136,7 +148,8 @@ def from_python_std(cls, string): return cls( task_id=Identifier(_core_identifier.ResourceType.TASK, tp, td, tn, tv), node_execution_id=_core_identifier.NodeExecutionIdentifier( - node_id=node_id, execution_id=_core_identifier.WorkflowExecutionIdentifier(ep, ed, en), + node_id=node_id, + execution_id=_core_identifier.WorkflowExecutionIdentifier(ep, ed, en), ), retry_attempt=int(retry), ) diff --git a/flytekit/common/exceptions/scopes.py b/flytekit/common/exceptions/scopes.py index fdd4e1a802..994211f6a6 100644 --- a/flytekit/common/exceptions/scopes.py +++ b/flytekit/common/exceptions/scopes.py @@ -159,7 +159,9 @@ def system_entry_point(wrapped, instance, args, kwargs): except _user_exceptions.FlyteUserException: # Re-raise from here. _reraise( - FlyteScopedUserException, FlyteScopedUserException(*_exc_info()), _exc_info()[2], + FlyteScopedUserException, + FlyteScopedUserException(*_exc_info()), + _exc_info()[2], ) except Exception: # System error, raise full stack-trace all the way up the chain. @@ -198,17 +200,23 @@ def user_entry_point(wrapped, instance, args, kwargs): _reraise(*_exc_info()) except _user_exceptions.FlyteUserException: _reraise( - FlyteScopedUserException, FlyteScopedUserException(*_exc_info()), _exc_info()[2], + FlyteScopedUserException, + FlyteScopedUserException(*_exc_info()), + _exc_info()[2], ) except _system_exceptions.FlyteSystemException: _reraise( - FlyteScopedSystemException, FlyteScopedSystemException(*_exc_info()), _exc_info()[2], + FlyteScopedSystemException, + FlyteScopedSystemException(*_exc_info()), + _exc_info()[2], ) except Exception: # Any non-platform raised exception is a user exception. 
# This will also catch FlyteUserException re-raised by the system_entry_point handler _reraise( - FlyteScopedUserException, FlyteScopedUserException(*_exc_info()), _exc_info()[2], + FlyteScopedUserException, + FlyteScopedUserException(*_exc_info()), + _exc_info()[2], ) finally: _CONTEXT_STACK.pop() diff --git a/flytekit/common/exceptions/user.py b/flytekit/common/exceptions/user.py index 671ebf66ec..acb5dd7997 100644 --- a/flytekit/common/exceptions/user.py +++ b/flytekit/common/exceptions/user.py @@ -34,7 +34,10 @@ def _create_verbose_message(cls, received_type, expected_type, received_value=No def __init__(self, received_type, expected_type, additional_msg=None, received_value=None): super(FlyteTypeException, self).__init__( self._create_verbose_message( - received_type, expected_type, received_value=received_value, additional_msg=additional_msg, + received_type, + expected_type, + received_value=received_value, + additional_msg=additional_msg, ) ) diff --git a/flytekit/common/interface.py b/flytekit/common/interface.py index 0f8e5569fa..c8f5160e15 100644 --- a/flytekit/common/interface.py +++ b/flytekit/common/interface.py @@ -30,7 +30,12 @@ def promote_from_model(cls, model): :param flytekit.models.literals.BindingData model: :rtype: BindingData """ - return cls(scalar=model.scalar, collection=model.collection, promise=model.promise, map=model.map,) + return cls( + scalar=model.scalar, + collection=model.collection, + promise=model.promise, + map=model.map, + ) @classmethod def from_python_std(cls, literal_type, t_value, upstream_nodes=None): @@ -75,7 +80,9 @@ def from_python_std(cls, literal_type, t_value, upstream_nodes=None): collection = _literal_models.BindingDataCollection( [ BindingData.from_python_std( - downstream_sdk_type.sub_type.to_flyte_literal_type(), v, upstream_nodes=upstream_nodes, + downstream_sdk_type.sub_type.to_flyte_literal_type(), + v, + upstream_nodes=upstream_nodes, ) for v in t_value ] diff --git a/flytekit/common/launch_plan.py b/flytekit/common/launch_plan.py index 0cb1cc872e..7098d3e556 100644 --- a/flytekit/common/launch_plan.py +++ b/flytekit/common/launch_plan.py @@ -184,7 +184,8 @@ def auth_role(self): ) assumable_iam_role = _sdk_config.ROLE.get() return _common_models.AuthRole( - assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, + assumable_iam_role=assumable_iam_role, + kubernetes_service_account=kubernetes_service_account, ) @property @@ -274,7 +275,13 @@ def execute_with_literals( Deprecated. 
""" return self.launch_with_literals( - project, domain, literal_inputs, name, notification_overrides, label_overrides, annotation_overrides, + project, + domain, + literal_inputs, + name, + notification_overrides, + label_overrides, + annotation_overrides, ) @_exception_scopes.system_entry_point @@ -425,7 +432,8 @@ def __init__( super(SdkRunnableLaunchPlan, self).__init__( None, _launch_plan_models.LaunchPlanMetadata( - schedule=schedule or _schedule_model.Schedule(""), notifications=notifications or [], + schedule=schedule or _schedule_model.Schedule(""), + notifications=notifications or [], ), _interface_models.ParameterMap(default_inputs), _type_helpers.pack_python_std_map_to_literal_map( @@ -442,7 +450,8 @@ def __init__( raw_output_data_config or _common_models.RawOutputDataConfig(""), ) self._interface = _interface.TypedInterface( - {k: v.var for k, v in _six.iteritems(default_inputs)}, sdk_workflow.interface.outputs, + {k: v.var for k, v in _six.iteritems(default_inputs)}, + sdk_workflow.interface.outputs, ) self._upstream_entities = {sdk_workflow} self._sdk_workflow = sdk_workflow diff --git a/flytekit/common/local_workflow.py b/flytekit/common/local_workflow.py index 9aaecf5385..eb2067578f 100644 --- a/flytekit/common/local_workflow.py +++ b/flytekit/common/local_workflow.py @@ -280,7 +280,8 @@ class provided should be a subclass of flytekit.common.launch_plan.SdkLaunchPlan if role: assumable_iam_role = role # For backwards compatibility auth_role = _common_models.AuthRole( - assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, + assumable_iam_role=assumable_iam_role, + kubernetes_service_account=kubernetes_service_account, ) raw_output_config = _common_models.RawOutputDataConfig(raw_output_data_prefix or "") @@ -359,13 +360,15 @@ def _discover_workflow_components(workflow_class): elif isinstance(current_obj, _promise.Input): if attribute_name is None or attribute_name not in top_level_attributes: raise _user_exceptions.FlyteValueException( - attribute_name, "Detected workflow input specified outside of top level.", + attribute_name, + "Detected workflow input specified outside of top level.", ) inputs.append(current_obj.rename_and_return_reference(attribute_name)) elif isinstance(current_obj, Output): if attribute_name is None or attribute_name not in top_level_attributes: raise _user_exceptions.FlyteValueException( - attribute_name, "Detected workflow output specified outside of top level.", + attribute_name, + "Detected workflow output specified outside of top level.", ) outputs.append(current_obj.rename_and_return_reference(attribute_name)) elif isinstance(current_obj, list) or isinstance(current_obj, set) or isinstance(current_obj, tuple): diff --git a/flytekit/common/mixins/launchable.py b/flytekit/common/mixins/launchable.py index 689623e867..110ba663af 100644 --- a/flytekit/common/mixins/launchable.py +++ b/flytekit/common/mixins/launchable.py @@ -119,5 +119,11 @@ def execute_with_literals( Deprecated. 
""" return self.launch_with_literals( - project, domain, literal_inputs, name, notification_overrides, label_overrides, annotation_overrides, + project, + domain, + literal_inputs, + name, + notification_overrides, + label_overrides, + annotation_overrides, ) diff --git a/flytekit/common/promise.py b/flytekit/common/promise.py index 2cf3bfba50..79352dd638 100644 --- a/flytekit/common/promise.py +++ b/flytekit/common/promise.py @@ -109,7 +109,13 @@ def promote_from_model(cls, model): if model.default is not None: default_value = sdk_type.from_flyte_idl(model.default.to_flyte_idl()).to_python_std() - return cls("", sdk_type, help=model.var.description, required=False, default=default_value,) + return cls( + "", + sdk_type, + help=model.var.description, + required=False, + default=default_value, + ) else: return cls("", sdk_type, help=model.var.description, required=True) diff --git a/flytekit/common/schedules.py b/flytekit/common/schedules.py index 250b7aff39..d6e71c6f83 100644 --- a/flytekit/common/schedules.py +++ b/flytekit/common/schedules.py @@ -165,15 +165,18 @@ def _translate_duration(duration): ) elif int(duration.total_seconds()) % _SECONDS_TO_DAYS == 0: return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_DAYS), _schedule_models.Schedule.FixedRateUnit.DAY, + int(duration.total_seconds() / _SECONDS_TO_DAYS), + _schedule_models.Schedule.FixedRateUnit.DAY, ) elif int(duration.total_seconds()) % _SECONDS_TO_HOURS == 0: return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_HOURS), _schedule_models.Schedule.FixedRateUnit.HOUR, + int(duration.total_seconds() / _SECONDS_TO_HOURS), + _schedule_models.Schedule.FixedRateUnit.HOUR, ) else: return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_MINUTES), _schedule_models.Schedule.FixedRateUnit.MINUTE, + int(duration.total_seconds() / _SECONDS_TO_MINUTES), + _schedule_models.Schedule.FixedRateUnit.MINUTE, ) @classmethod diff --git a/flytekit/common/tasks/generic_spark_task.py b/flytekit/common/tasks/generic_spark_task.py index 3a7c0ef1bc..a83a7afbe4 100644 --- a/flytekit/common/tasks/generic_spark_task.py +++ b/flytekit/common/tasks/generic_spark_task.py @@ -76,7 +76,11 @@ def __init__( task_type, _task_models.TaskMetadata( discoverable, - _task_models.RuntimeMetadata(_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "spark",), + _task_models.RuntimeMetadata( + _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + __version__, + "spark", + ), timeout, _literal_models.RetryStrategy(retries), interruptible, @@ -121,7 +125,8 @@ def add_inputs(self, inputs): self.interface.inputs.update(inputs) def _get_container_definition( - self, environment=None, + self, + environment=None, ): """ :rtype: Container diff --git a/flytekit/common/tasks/hive_task.py b/flytekit/common/tasks/hive_task.py index 6023bf0ec3..6be5db70f9 100644 --- a/flytekit/common/tasks/hive_task.py +++ b/flytekit/common/tasks/hive_task.py @@ -111,7 +111,9 @@ def _generate_plugin_objects(self, context, inputs_dict): for q in queries_from_task: hive_query = _qubole.HiveQuery( - query=q, timeout_sec=self.metadata.timeout.seconds, retry_count=self.metadata.retries.retries, + query=q, + timeout_sec=self.metadata.timeout.seconds, + retry_count=self.metadata.retries.retries, ) # TODO: Remove this after all users of older SDK versions that did the single node, multi-query pattern are @@ -121,7 +123,12 @@ def _generate_plugin_objects(self, context, inputs_dict): 
query_collection = _qubole.HiveQueryCollection([hive_query]) plugin_objects.append( - _qubole.QuboleHiveJob(hive_query, self._cluster_label, self._tags, query_collection=query_collection,) + _qubole.QuboleHiveJob( + hive_query, + self._cluster_label, + self._tags, + query_collection=query_collection, + ) ) return plugin_objects @@ -145,11 +152,13 @@ def _validate_task_parameters(cluster_label, tags): ) if len(tags) > ALLOWED_TAGS_COUNT: raise _FlyteValueException( - len(tags), "number of tags must be less than {}".format(ALLOWED_TAGS_COUNT), + len(tags), + "number of tags must be less than {}".format(ALLOWED_TAGS_COUNT), ) if not all(len(tag) for tag in tags): raise _FlyteValueException( - tags, "length of a tag must be less than {} chars".format(MAX_TAG_LENGTH), + tags, + "length of a tag must be less than {} chars".format(MAX_TAG_LENGTH), ) @staticmethod @@ -190,7 +199,8 @@ def _produce_dynamic_job_spec(self, context, inputs): # Create output bindings always - this has to happen after user code has run output_bindings = [ _literal_models.Binding( - var=name, binding=_interface.BindingData.from_python_std(b.sdk_type.to_flyte_literal_type(), b.value), + var=name, + binding=_interface.BindingData.from_python_std(b.sdk_type.to_flyte_literal_type(), b.value), ) for name, b in _six.iteritems(outputs_dict) ] @@ -203,7 +213,11 @@ def _produce_dynamic_job_spec(self, context, inputs): i += 1 dynamic_job_spec = _dynamic_job.DynamicJobSpec( - min_successes=len(nodes), tasks=tasks, nodes=nodes, outputs=output_bindings, subworkflows=[], + min_successes=len(nodes), + tasks=tasks, + nodes=nodes, + outputs=output_bindings, + subworkflows=[], ) return dynamic_job_spec @@ -263,7 +277,9 @@ class SdkHiveJob(_base_task.SdkTask): """ def __init__( - self, hive_job, metadata, + self, + hive_job, + metadata, ): """ :param _qubole.QuboleHiveJob hive_job: Hive job spec diff --git a/flytekit/common/tasks/presto_task.py b/flytekit/common/tasks/presto_task.py index 1bf8d0e51c..47e198494d 100644 --- a/flytekit/common/tasks/presto_task.py +++ b/flytekit/common/tasks/presto_task.py @@ -69,7 +69,10 @@ def __init__( ) presto_query = _presto_models.PrestoQuery( - routing_group=routing_group or "", catalog=catalog or "", schema=schema or "", statement=statement, + routing_group=routing_group or "", + catalog=catalog or "", + schema=schema or "", + statement=statement, ) # Here we set the routing_group, catalog, and schema as implicit @@ -99,7 +102,10 @@ def __init__( ) super(SdkPrestoTask, self).__init__( - _constants.SdkTaskType.PRESTO_TASK, metadata, i, _MessageToDict(presto_query.to_flyte_idl()), + _constants.SdkTaskType.PRESTO_TASK, + metadata, + i, + _MessageToDict(presto_query.to_flyte_idl()), ) # Set user provided inputs diff --git a/flytekit/common/tasks/raw_container.py b/flytekit/common/tasks/raw_container.py index 9f8a5ccc1b..c5437477a9 100644 --- a/flytekit/common/tasks/raw_container.py +++ b/flytekit/common/tasks/raw_container.py @@ -158,7 +158,9 @@ def __init__( discoverable, # This needs to have the proper version reflected in it _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, flytekit.__version__, "python", + _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + flytekit.__version__, + "python", ), timeout or _datetime.timedelta(seconds=0), _literals.RetryStrategy(retries), diff --git a/flytekit/common/tasks/sagemaker/built_in_training_job_task.py b/flytekit/common/tasks/sagemaker/built_in_training_job_task.py index e1e235df09..356f933c3e 100644 --- 
a/flytekit/common/tasks/sagemaker/built_in_training_job_task.py +++ b/flytekit/common/tasks/sagemaker/built_in_training_job_task.py @@ -41,7 +41,8 @@ def __init__( """ # Use the training job model as a measure of type checking self._training_job_model = _training_job_models.TrainingJob( - algorithm_specification=algorithm_specification, training_job_resource_config=training_job_resource_config, + algorithm_specification=algorithm_specification, + training_job_resource_config=training_job_resource_config, ) # Setting flyte-level timeout to 0, and let SageMaker takes the StoppingCondition and terminate the training @@ -52,7 +53,9 @@ def __init__( type=SdkTaskType.SAGEMAKER_TRAINING_JOB_TASK, metadata=_task_models.TaskMetadata( runtime=_task_models.RuntimeMetadata( - type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, version=__version__, flavor="sagemaker", + type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + version=__version__, + flavor="sagemaker", ), discoverable=cacheable, timeout=timeout, @@ -64,7 +67,8 @@ def __init__( interface=_interface.TypedInterface( inputs={ "static_hyperparameters": _interface_model.Variable( - type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT), description="", + type=_idl_types.LiteralType(simple=_idl_types.SimpleType.STRUCT), + description="", ), "train": _interface_model.Variable( type=_idl_types.LiteralType( @@ -89,7 +93,8 @@ def __init__( "model": _interface_model.Variable( type=_idl_types.LiteralType( blob=_core_types.BlobType( - format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + format="", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) ), description="", diff --git a/flytekit/common/tasks/sagemaker/hpo_job_task.py b/flytekit/common/tasks/sagemaker/hpo_job_task.py index 518816cf2b..fd4d0a3ce1 100644 --- a/flytekit/common/tasks/sagemaker/hpo_job_task.py +++ b/flytekit/common/tasks/sagemaker/hpo_job_task.py @@ -63,7 +63,8 @@ def __init__( inputs.update( { "hyperparameter_tuning_job_config": _interface_model.Variable( - HyperparameterTuningJobConfig.to_flyte_literal_type(), "", + HyperparameterTuningJobConfig.to_flyte_literal_type(), + "", ), } ) @@ -80,7 +81,9 @@ def __init__( type=SdkTaskType.SAGEMAKER_HYPERPARAMETER_TUNING_JOB_TASK, metadata=_task_models.TaskMetadata( runtime=_task_models.RuntimeMetadata( - type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, version=__version__, flavor="sagemaker", + type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + version=__version__, + flavor="sagemaker", ), discoverable=cacheable, timeout=timeout, @@ -95,7 +98,8 @@ def __init__( "model": _interface_model.Variable( type=_types_models.LiteralType( blob=_core_types.BlobType( - format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + format="", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) ), description="", diff --git a/flytekit/common/tasks/sdk_dynamic.py b/flytekit/common/tasks/sdk_dynamic.py index e1ffc6f861..2a44111be5 100644 --- a/flytekit/common/tasks/sdk_dynamic.py +++ b/flytekit/common/tasks/sdk_dynamic.py @@ -81,7 +81,9 @@ def _create_array_job(self, inputs_prefix): :rtype: _array_job.ArrayJob """ return _array_job.ArrayJob( - parallelism=self._max_concurrency if self._max_concurrency else 0, size=1, min_successes=1, + parallelism=self._max_concurrency if self._max_concurrency else 0, + size=1, + min_successes=1, ) @staticmethod @@ -137,7 +139,9 @@ def _produce_dynamic_job_spec(self, context, inputs): _literal_models.Binding( var=name, 
binding=_interface.BindingData.from_python_std( - b.sdk_type.to_flyte_literal_type(), b.raw_value, upstream_nodes=upstream_nodes, + b.sdk_type.to_flyte_literal_type(), + b.raw_value, + upstream_nodes=upstream_nodes, ), ) for name, b in _six.iteritems(outputs_dict) @@ -284,7 +288,9 @@ def execute(self, context, inputs): class SdkDynamicTask( - SdkDynamicTaskMixin, _sdk_runnable.SdkRunnableTask, metaclass=_sdk_bases.ExtendedSdkType, + SdkDynamicTaskMixin, + _sdk_runnable.SdkRunnableTask, + metaclass=_sdk_bases.ExtendedSdkType, ): """ diff --git a/flytekit/common/tasks/sdk_runnable.py b/flytekit/common/tasks/sdk_runnable.py index 522eccdc7c..d938d51458 100644 --- a/flytekit/common/tasks/sdk_runnable.py +++ b/flytekit/common/tasks/sdk_runnable.py @@ -242,7 +242,12 @@ class SdkRunnableContainer(_task_models.Container, metaclass=_sdk_bases.Extended """ def __init__( - self, command, args, resources, env, config, + self, + command, + args, + resources, + env, + config, ): super(SdkRunnableContainer, self).__init__("", command, args, resources, env or {}, config) @@ -396,7 +401,9 @@ def __init__( _task_models.TaskMetadata( discoverable, _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "python", + _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + __version__, + "python", ), timeout, _literal_models.RetryStrategy(retries), diff --git a/flytekit/common/tasks/sidecar_task.py b/flytekit/common/tasks/sidecar_task.py index bf1f6019ec..51be8a8f40 100644 --- a/flytekit/common/tasks/sidecar_task.py +++ b/flytekit/common/tasks/sidecar_task.py @@ -132,14 +132,17 @@ def reconcile_partial_pod_spec_and_task(self, pod_spec, primary_container_name): pod_spec.containers.extend(final_containers) sidecar_job_plugin = _task_models.SidecarJob( - pod_spec=pod_spec, primary_container_name=primary_container_name, + pod_spec=pod_spec, + primary_container_name=primary_container_name, ).to_flyte_idl() self.assign_custom_and_return(_MessageToDict(sidecar_job_plugin)) class SdkDynamicSidecarTask( - _sdk_dynamic.SdkDynamicTaskMixin, SdkSidecarTask, metaclass=_sdk_bases.ExtendedSdkType, + _sdk_dynamic.SdkDynamicTaskMixin, + SdkSidecarTask, + metaclass=_sdk_bases.ExtendedSdkType, ): """ diff --git a/flytekit/common/tasks/task.py b/flytekit/common/tasks/task.py index 07e4cee420..ba55399382 100644 --- a/flytekit/common/tasks/task.py +++ b/flytekit/common/tasks/task.py @@ -143,7 +143,10 @@ def __call__(self, *args, **input_map): return _nodes.SdkNode( id=None, metadata=_workflow_model.NodeMetadata( - "DEADBEEF", self.metadata.timeout, self.metadata.retries, self.metadata.interruptible, + "DEADBEEF", + self.metadata.timeout, + self.metadata.retries, + self.metadata.interruptible, ), bindings=sorted(bindings, key=lambda b: b.var), upstream_nodes=upstream_nodes, @@ -216,7 +219,9 @@ def fetch_latest(cls, project, domain, name): named_task = _common_model.NamedEntityIdentifier(project, domain, name) client = _flyte_engine.get_client() task_list, _ = client.list_tasks_paginated( - named_task, limit=1, sort_by=_admin_common.Sort("created_at", _admin_common.Sort.Direction.DESCENDING), + named_task, + limit=1, + sort_by=_admin_common.Sort("created_at", _admin_common.Sort.Direction.DESCENDING), ) admin_task = task_list[0] if task_list else None @@ -386,7 +391,8 @@ def launch_with_literals( ) assumable_iam_role = _sdk_config.ROLE.get() auth_role = _common_model.AuthRole( - assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, + 
assumable_iam_role=assumable_iam_role, + kubernetes_service_account=kubernetes_service_account, ) client = _flyte_engine.get_client() diff --git a/flytekit/common/translator.py b/flytekit/common/translator.py index fc4014d9ac..4a4218f07f 100644 --- a/flytekit/common/translator.py +++ b/flytekit/common/translator.py @@ -61,7 +61,10 @@ def to_serializable_cases( def get_serializable_references( - entity_mapping: OrderedDict, settings: SerializationSettings, entity: FlyteLocalEntity, fast: bool, + entity_mapping: OrderedDict, + settings: SerializationSettings, + entity: FlyteLocalEntity, + fast: bool, ) -> FlyteControlPlaneEntity: # TODO: This entire function isn't necessary. We should just return None or raise an Exception or something. # Reference entities should already exist on the Admin control plane - they should not be serialized/registered @@ -114,7 +117,10 @@ def get_serializable_references( def get_serializable_task( - entity_mapping: OrderedDict, settings: SerializationSettings, entity: FlyteLocalEntity, fast: bool, + entity_mapping: OrderedDict, + settings: SerializationSettings, + entity: FlyteLocalEntity, + fast: bool, ) -> FlyteControlPlaneEntity: cp_entity = SdkTask( type=entity.task_type, @@ -152,7 +158,10 @@ def get_serializable_task( def get_serializable_workflow( - entity_mapping: OrderedDict, settings: SerializationSettings, entity: WorkflowBase, fast: bool, + entity_mapping: OrderedDict, + settings: SerializationSettings, + entity: WorkflowBase, + fast: bool, ) -> FlyteControlPlaneEntity: workflow_id = _identifier_model.Identifier( _identifier_model.ResourceType.WORKFLOW, settings.project, settings.domain, entity.name, settings.version @@ -182,13 +191,17 @@ def get_serializable_workflow( def get_serializable_launch_plan( - entity_mapping: OrderedDict, settings: SerializationSettings, entity: FlyteLocalEntity, fast: bool, + entity_mapping: OrderedDict, + settings: SerializationSettings, + entity: FlyteLocalEntity, + fast: bool, ) -> FlyteControlPlaneEntity: sdk_workflow = get_serializable(entity_mapping, settings, entity.workflow) cp_entity = SdkLaunchPlan( workflow_id=sdk_workflow.id, entity_metadata=_launch_plan_models.LaunchPlanMetadata( - schedule=entity.schedule, notifications=entity.notifications, + schedule=entity.schedule, + notifications=entity.notifications, ), default_inputs=entity.parameters, fixed_inputs=entity.fixed_inputs, @@ -213,7 +226,10 @@ def get_serializable_launch_plan( def get_serializable_node( - entity_mapping: OrderedDict, settings: SerializationSettings, entity: FlyteLocalEntity, fast: bool, + entity_mapping: OrderedDict, + settings: SerializationSettings, + entity: FlyteLocalEntity, + fast: bool, ) -> FlyteControlPlaneEntity: if entity._flyte_entity is None: raise Exception(f"Node {entity.id} has no flyte entity") @@ -269,7 +285,10 @@ def get_serializable_node( def get_serializable_branch_node( - entity_mapping: OrderedDict, settings: SerializationSettings, entity: FlyteLocalEntity, fast: bool, + entity_mapping: OrderedDict, + settings: SerializationSettings, + entity: FlyteLocalEntity, + fast: bool, ) -> FlyteControlPlaneEntity: # We have to iterate through the blocks to convert the nodes from their current type to SDKNode # TODO this should be cleaned up instead of mutation, we probaby should just create a new object diff --git a/flytekit/common/types/blobs.py b/flytekit/common/types/blobs.py index 206fbf198f..7870cb75bd 100644 --- a/flytekit/common/types/blobs.py +++ b/flytekit/common/types/blobs.py @@ -168,7 +168,8 @@ def from_string(cls, 
string_value): """ if not string_value: _user_exceptions.FlyteValueException( - string_value, "Cannot create a MultiPartBlob from the provided path " "value.", + string_value, + "Cannot create a MultiPartBlob from the provided path " "value.", ) return cls(_blob_impl.MultiPartBlob.from_string(string_value, mode="rb")) @@ -201,7 +202,10 @@ def to_flyte_literal_type(cls): :rtype: flytekit.models.types.LiteralType """ return _idl_types.LiteralType( - blob=_core_types.BlobType(format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ) @classmethod @@ -321,7 +325,10 @@ def to_flyte_literal_type(cls): :rtype: flytekit.models.types.LiteralType """ return _idl_types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) ) @classmethod @@ -391,7 +398,8 @@ def from_string(cls, string_value): """ if not string_value: _user_exceptions.FlyteValueException( - string_value, "Cannot create a MultiPartCSV from the provided path value.", + string_value, + "Cannot create a MultiPartCSV from the provided path value.", ) return cls(_blob_impl.MultiPartBlob.from_string(string_value, format="csv", mode="r")) @@ -428,7 +436,10 @@ def to_flyte_literal_type(cls): :rtype: flytekit.models.types.LiteralType """ return _idl_types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ) @classmethod diff --git a/flytekit/common/types/containers.py b/flytekit/common/types/containers.py index f61a882bd0..9267273b1e 100644 --- a/flytekit/common/types/containers.py +++ b/flytekit/common/types/containers.py @@ -58,12 +58,16 @@ def from_string(cls, string_value): items = _json.loads(string_value) except ValueError: raise _user_exceptions.FlyteTypeException( - _six.text_type, cls, additional_msg="String not parseable to json {}".format(string_value), + _six.text_type, + cls, + additional_msg="String not parseable to json {}".format(string_value), ) if type(items) != list: raise _user_exceptions.FlyteTypeException( - _six.text_type, cls, additional_msg="String is not a list {}".format(string_value), + _six.text_type, + cls, + additional_msg="String is not a list {}".format(string_value), ) # Instead of recursively calling from_string(), we're changing to from_python_std() instead because json @@ -137,7 +141,9 @@ def short_string(self): if len(self.collection.literals) > num_to_print: to_print.append("...") return "{}(len={}, [{}])".format( - type(self).short_class_string(), len(self.collection.literals), ", ".join(to_print), + type(self).short_class_string(), + len(self.collection.literals), + ", ".join(to_print), ) def verbose_string(self): diff --git a/flytekit/common/types/impl/blobs.py b/flytekit/common/types/impl/blobs.py index f69d4215b8..45210da62a 100644 --- a/flytekit/common/types/impl/blobs.py +++ b/flytekit/common/types/impl/blobs.py @@ -349,7 +349,8 @@ def __enter__(self): "path is specified." 
) self._directory = _utils.AutoDeletingTempDir( - _uuid.uuid4().hex, tmp_dir=_data_proxy.LocalWorkingDirectoryContext.get().name, + _uuid.uuid4().hex, + tmp_dir=_data_proxy.LocalWorkingDirectoryContext.get().name, ) self._is_managed = True self._directory.__enter__() @@ -361,7 +362,10 @@ def __enter__(self): self._blobs = [] file_handles = [] for local_path in sorted(self._directory.list_dir(), key=lambda x: x.lower()): - b = Blob(_os.path.join(self.remote_location, _os.path.basename(local_path)), mode=self.mode,) + b = Blob( + _os.path.join(self.remote_location, _os.path.basename(local_path)), + mode=self.mode, + ) b._local_path = local_path file_handles.append(b.__enter__()) self._blobs.append(b) @@ -426,10 +430,13 @@ def create_part(self, name=None): name = _uuid.uuid4().hex if ":" in name or "/" in name: raise _user_exceptions.FlyteAssertion( - name, "Cannot create a part of a multi-part object with ':' or '/' in the name.", + name, + "Cannot create a part of a multi-part object with ':' or '/' in the name.", ) return Blob.create_at_known_location( - _os.path.join(self.remote_location, name), mode=self.mode, format=self.metadata.type.format, + _os.path.join(self.remote_location, name), + mode=self.mode, + format=self.metadata.type.format, ) @_exception_scopes.system_entry_point diff --git a/flytekit/common/types/impl/schema.py b/flytekit/common/types/impl/schema.py index 10a28eaa29..04e4109d1c 100644 --- a/flytekit/common/types/impl/schema.py +++ b/flytekit/common/types/impl/schema.py @@ -352,7 +352,9 @@ def write(self, data_frame, coerce_timestamps="us", allow_truncated_timestamps=F try: filename = self._local_dir.get_named_tempfile(_os.path.join(str(self._index).zfill(6))) data_frame.to_parquet( - filename, coerce_timestamps=coerce_timestamps, allow_truncated_timestamps=allow_truncated_timestamps, + filename, + coerce_timestamps=coerce_timestamps, + allow_truncated_timestamps=allow_truncated_timestamps, ) if self._index == len(self._chunks): self._chunks.append(filename) @@ -379,7 +381,8 @@ def __enter__(self): "specify a path when calling this function." ) self._directory = _utils.AutoDeletingTempDir( - _uuid.uuid4().hex, tmp_dir=_data_proxy.LocalWorkingDirectoryContext.get().name, + _uuid.uuid4().hex, + tmp_dir=_data_proxy.LocalWorkingDirectoryContext.get().name, ) self._is_managed = True self._directory.__enter__() @@ -630,7 +633,12 @@ def from_string(cls, string_value, schema_type=None): @classmethod @_exception_scopes.system_entry_point def create_from_hive_query( - cls, select_query, stage_query=None, schema_to_table_name_map=None, schema_type=None, known_location=None, + cls, + select_query, + stage_query=None, + schema_to_table_name_map=None, + schema_type=None, + known_location=None, ): """ Returns a query that can be submitted to Hive and produce the desired output. 
It also returns a properly-typed @@ -647,7 +655,9 @@ def create_from_hive_query( :return: Schema, Text """ schema_object = cls( - known_location or _data_proxy.Data.get_remote_directory(), mode="wb", schema_type=schema_type, + known_location or _data_proxy.Data.get_remote_directory(), + mode="wb", + schema_type=schema_type, ) if len(schema_object.type.sdk_columns) > 0: @@ -660,13 +670,15 @@ def create_from_hive_query( if sdk_type == _primitives.Float: columnar_clauses.append( "CAST({table_column_name} as double) {schema_name}".format( - table_column_name=schema_to_table_name_map[name], schema_name=name, + table_column_name=schema_to_table_name_map[name], + schema_name=name, ) ) else: columnar_clauses.append( "{table_column_name} as {schema_name}".format( - table_column_name=schema_to_table_name_map[name], schema_name=name, + table_column_name=schema_to_table_name_map[name], + schema_name=name, ) ) columnar_query = ",\n\t\t".join(columnar_clauses) @@ -844,7 +856,8 @@ def get_write_partition_to_hive_table_query( for partition_name, partition_value in partitions: where_clauses.append( "\n\t\t{schema_name} = {value_str} AND ".format( - schema_name=table_to_schema_name_map[partition_name], value_str=partition_value, + schema_name=table_to_schema_name_map[partition_name], + value_str=partition_value, ) ) where_string = "WHERE\n\t\t{where_clauses}".format(where_clauses=" AND\n\t\t".join(where_clauses)) @@ -863,7 +876,9 @@ def get_write_partition_to_hive_table_query( ) return _format_insert_partition_query( - remote_location=self.remote_location, table_name=table_name, partition_string=partition_string, + remote_location=self.remote_location, + table_name=table_name, + partition_string=partition_string, ) def compare_dataframe_to_schema(self, data_frame, column_subset=None, read=False): diff --git a/flytekit/common/types/primitives.py b/flytekit/common/types/primitives.py index 77fbf7b313..6fa61c01ab 100644 --- a/flytekit/common/types/primitives.py +++ b/flytekit/common/types/primitives.py @@ -183,7 +183,9 @@ def from_string(cls, string_value): elif string_value == "0" or string_value.lower() == "false": return cls(False) raise _user_exceptions.FlyteTypeException( - _six.text_type, bool, additional_msg="String not castable to Boolean SDK " "type: {}".format(string_value), + _six.text_type, + bool, + additional_msg="String not castable to Boolean SDK " "type: {}".format(string_value), ) @classmethod @@ -377,7 +379,8 @@ def from_python_std(cls, t_value): raise _user_exceptions.FlyteTypeException(type(t_value), _datetime.datetime, t_value) elif t_value.tzinfo is None: raise _user_exceptions.FlyteValueException( - t_value, "Datetime objects in Flyte must be timezone aware. " "tzinfo was found to be None.", + t_value, + "Datetime objects in Flyte must be timezone aware. " "tzinfo was found to be None.", ) return cls(t_value) diff --git a/flytekit/common/types/schema.py b/flytekit/common/types/schema.py index 8f0a9555d1..eaf38d1c88 100644 --- a/flytekit/common/types/schema.py +++ b/flytekit/common/types/schema.py @@ -29,7 +29,11 @@ def create(cls): return _schema_impl.Schema.create_at_any_location(mode="wb", schema_type=cls.schema_type) def create_from_hive_query( - cls, select_query, stage_query=None, schema_to_table_name_map=None, known_location=None, + cls, + select_query, + stage_query=None, + schema_to_table_name_map=None, + known_location=None, ): """ Returns a query that can be submitted to Hive and produce the desired output. 
It also returns a properly-typed @@ -150,7 +154,9 @@ def short_string(self): """ :rtype: Text """ - return "{}".format(self.scalar.schema,) + return "{}".format( + self.scalar.schema, + ) def schema_instantiator(columns=None): diff --git a/flytekit/common/workflow.py b/flytekit/common/workflow.py index 42cb946579..3fc4f498fd 100644 --- a/flytekit/common/workflow.py +++ b/flytekit/common/workflow.py @@ -39,7 +39,13 @@ class only a control plane class. Workflow constructs that rely on local code be """ def __init__( - self, nodes, interface, output_bindings, id, metadata, metadata_defaults, + self, + nodes, + interface, + output_bindings, + id, + metadata, + metadata_defaults, ): """ :param list[flytekit.common.nodes.SdkNode] nodes: @@ -221,7 +227,13 @@ def register(self, project, domain, name, version): try: client = _flyte_engine.get_client() sub_workflows = self.get_sub_workflows() - client.create_workflow(id_to_register, _admin_workflow_model.WorkflowSpec(self, sub_workflows,)) + client.create_workflow( + id_to_register, + _admin_workflow_model.WorkflowSpec( + self, + sub_workflows, + ), + ) self._id = id_to_register self._has_registered = True return str(id_to_register) @@ -240,7 +252,10 @@ def serialize(self): :rtype: flyteidl.admin.workflow_pb2.WorkflowSpec """ sub_workflows = self.get_sub_workflows() - return _admin_workflow_model.WorkflowSpec(self, sub_workflows,).to_flyte_idl() + return _admin_workflow_model.WorkflowSpec( + self, + sub_workflows, + ).to_flyte_idl() @_exception_scopes.system_entry_point def validate(self): @@ -255,13 +270,15 @@ def create_launch_plan(self, *args, **kwargs): if not (assumable_iam_role or kubernetes_service_account): raise _user_exceptions.FlyteValidationException("No assumable role or service account found") auth_role = _common_models.AuthRole( - assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, + assumable_iam_role=assumable_iam_role, + kubernetes_service_account=kubernetes_service_account, ) return SdkLaunchPlan( workflow_id=self.id, entity_metadata=_launch_plan_models.LaunchPlanMetadata( - schedule=_schedule_models.Schedule(""), notifications=[], + schedule=_schedule_models.Schedule(""), + notifications=[], ), default_inputs=_interface_models.ParameterMap({}), fixed_inputs=_literal_models.LiteralMap(literals={}), diff --git a/flytekit/common/workflow_execution.py b/flytekit/common/workflow_execution.py index f5064cc136..14695d0e68 100644 --- a/flytekit/common/workflow_execution.py +++ b/flytekit/common/workflow_execution.py @@ -19,7 +19,9 @@ class SdkWorkflowExecution( - _execution_models.Execution, _artifact.ExecutionArtifact, metaclass=_sdk_bases.ExtendedSdkType, + _execution_models.Execution, + _artifact.ExecutionArtifact, + metaclass=_sdk_bases.ExtendedSdkType, ): def __init__(self, *args, **kwargs): super(SdkWorkflowExecution, self).__init__(*args, **kwargs) diff --git a/flytekit/contrib/notebook/tasks.py b/flytekit/contrib/notebook/tasks.py index 14ef50f668..5836717574 100644 --- a/flytekit/contrib/notebook/tasks.py +++ b/flytekit/contrib/notebook/tasks.py @@ -123,7 +123,9 @@ def __init__( _task_models.TaskMetadata( discoverable, _task_models.RuntimeMetadata( - _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, __version__, "notebook", + _task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK, + __version__, + "notebook", ), timeout, _literal_models.RetryStrategy(retries), @@ -369,7 +371,13 @@ def _get_container_definition( storage_request, cpu_request, gpu_request, memory_request, storage_limit, 
cpu_limit, gpu_limit, memory_limit ) - return _sdk_runnable.SdkRunnableContainer(command=[], args=[], resources=resources, env=environment, config={},) + return _sdk_runnable.SdkRunnableContainer( + command=[], + args=[], + resources=resources, + env=environment, + config={}, + ) def spark_notebook( diff --git a/flytekit/contrib/sensors/impl.py b/flytekit/contrib/sensors/impl.py index 670df86b5d..8276320922 100644 --- a/flytekit/contrib/sensors/impl.py +++ b/flytekit/contrib/sensors/impl.py @@ -96,7 +96,10 @@ def _do_poll(self): """ with self._hive_metastore_client as client: partitions = client.get_partitions_by_filter( - db_name=self._schema, tbl_name=self._table_name, filter=self._partition_filter, max_parts=1, + db_name=self._schema, + tbl_name=self._table_name, + filter=self._partition_filter, + max_parts=1, ) if partitions: return True, None diff --git a/flytekit/contrib/sensors/task.py b/flytekit/contrib/sensors/task.py index 2bacae122f..31e8a8d14e 100644 --- a/flytekit/contrib/sensors/task.py +++ b/flytekit/contrib/sensors/task.py @@ -10,7 +10,8 @@ def _execute_user_code(self, context, inputs): if sensor is not None: if not isinstance(sensor, _Sensor): raise _user_exceptions.FlyteTypeException( - received_type=type(sensor), expected_type=_Sensor, + received_type=type(sensor), + expected_type=_Sensor, ) succeeded = sensor.sense() if not succeeded: diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 68a895d0d3..6a1bf0b7dd 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -58,7 +58,7 @@ class TaskMetadata(object): retries: for retries=n; n > 0, on failures of this task, the task will be retried at-least n number of times. timeout: the max amount of time for which one execution of this task should be executed for. 
The execution will be terminated if the runtime exceeds the given timeout (approximately) - """ + """ cache: bool = False cache_version: str = "" @@ -278,7 +278,9 @@ def get_config(self, settings: SerializationSettings) -> Dict[str, str]: @abstractmethod def dispatch_execute( - self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap, + self, + ctx: FlyteContext, + input_literal_map: _literal_models.LiteralMap, ) -> _literal_models.LiteralMap: """ This method translates Flyte's Type system based input values and invokes the actual call to the executor @@ -332,7 +334,10 @@ def __init__( a dictionary of key/value pairs """ super().__init__( - task_type=task_type, name=name, interface=transform_interface_to_typed_interface(interface), **kwargs, + task_type=task_type, + name=name, + interface=transform_interface_to_typed_interface(interface), + **kwargs, ) self._python_interface = interface if interface else Interface() self._environment = environment if environment else {} diff --git a/flytekit/core/condition.py b/flytekit/core/condition.py index 65395e975e..60d536ec88 100644 --- a/flytekit/core/condition.py +++ b/flytekit/core/condition.py @@ -300,7 +300,11 @@ def transform_to_conj_expr( left, left_promises = transform_to_boolexpr(expr.lhs) right, right_promises = transform_to_boolexpr(expr.rhs) return ( - _core_cond.ConjunctionExpression(left_expression=left, right_expression=right, operator=_logical_ops[expr.op],), + _core_cond.ConjunctionExpression( + left_expression=left, + right_expression=right, + operator=_logical_ops[expr.op], + ), merge_promises(*left_promises, *right_promises), ) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 02df45c988..894d44fc0e 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -621,8 +621,7 @@ def node_id(self): @property def node(self) -> Node: - """ - """ + """""" return self._node def __repr__(self) -> str: diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index b58b538564..6a574e5576 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -77,7 +77,11 @@ def __init__( raise AssertionError(f"Secret {s} should be of type flytekit.Secret, received {type(s)}") sec_ctx = SecurityContext(secrets=secret_requests) super().__init__( - task_type=task_type, name=name, task_config=task_config, security_ctx=sec_ctx, **kwargs, + task_type=task_type, + name=name, + task_config=task_config, + security_ctx=sec_ctx, + **kwargs, ) self._container_image = container_image # TODO(katrogan): Implement resource overrides diff --git a/flytekit/core/reference_entity.py b/flytekit/core/reference_entity.py index 4030cf22fb..ae260a1fd9 100644 --- a/flytekit/core/reference_entity.py +++ b/flytekit/core/reference_entity.py @@ -176,7 +176,12 @@ def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], P return create_task_output(vals, self.python_interface) def compile(self, ctx: FlyteContext, *args, **kwargs): - return create_and_link_node(ctx, entity=self, interface=self.python_interface, **kwargs,) + return create_and_link_node( + ctx, + entity=self, + interface=self.python_interface, + **kwargs, + ) def __call__(self, *args, **kwargs): # When a Task is () aka __called__, there are three things we may do: diff --git a/flytekit/core/schedule.py b/flytekit/core/schedule.py index 2e2847d9d8..9a16a9dda6 100644 --- a/flytekit/core/schedule.py +++ b/flytekit/core/schedule.py @@ -170,13 +170,16 @@ def
_translate_duration(duration: datetime.timedelta): ) elif int(duration.total_seconds()) % _SECONDS_TO_DAYS == 0: return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_DAYS), _schedule_models.Schedule.FixedRateUnit.DAY, + int(duration.total_seconds() / _SECONDS_TO_DAYS), + _schedule_models.Schedule.FixedRateUnit.DAY, ) elif int(duration.total_seconds()) % _SECONDS_TO_HOURS == 0: return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_HOURS), _schedule_models.Schedule.FixedRateUnit.HOUR, + int(duration.total_seconds() / _SECONDS_TO_HOURS), + _schedule_models.Schedule.FixedRateUnit.HOUR, ) else: return _schedule_models.Schedule.FixedRate( - int(duration.total_seconds() / _SECONDS_TO_MINUTES), _schedule_models.Schedule.FixedRateUnit.MINUTE, + int(duration.total_seconds() / _SECONDS_TO_MINUTES), + _schedule_models.Schedule.FixedRateUnit.MINUTE, ) diff --git a/flytekit/core/task.py b/flytekit/core/task.py index e396de10f0..c32b1e393c 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -199,7 +199,10 @@ def __init__( def reference_task( - project: str, domain: str, name: str, version: str, + project: str, + domain: str, + name: str, + version: str, ) -> Callable[[Callable[..., Any]], ReferenceTask]: """ A reference task is a pointer to a task that already exists on your Flyte installation. This diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index e2f891202e..e17137f8ae 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -423,11 +423,14 @@ def __init__(self): def _blob_type(self) -> _core_types.BlobType: return _core_types.BlobType( - format=mimetypes.types_map[".txt"], dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + format=mimetypes.types_map[".txt"], + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) def get_literal_type(self, t: typing.TextIO) -> LiteralType: - return _type_models.LiteralType(blob=self._blob_type(),) + return _type_models.LiteralType( + blob=self._blob_type(), + ) def to_literal( self, ctx: FlyteContext, python_val: typing.TextIO, python_type: Type[typing.TextIO], expected: LiteralType @@ -454,11 +457,14 @@ def __init__(self): def _blob_type(self) -> _core_types.BlobType: return _core_types.BlobType( - format=mimetypes.types_map[".bin"], dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + format=mimetypes.types_map[".bin"], + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) def get_literal_type(self, t: Type[typing.BinaryIO]) -> LiteralType: - return _type_models.LiteralType(blob=self._blob_type(),) + return _type_models.LiteralType( + blob=self._blob_type(), + ) def to_literal( self, ctx: FlyteContext, python_val: typing.BinaryIO, python_type: Type[typing.BinaryIO], expected: LiteralType @@ -484,11 +490,14 @@ def __init__(self): def _blob_type(self) -> _core_types.BlobType: return _core_types.BlobType( - format=mimetypes.types_map[".bin"], dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + format=mimetypes.types_map[".bin"], + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, ) def get_literal_type(self, t: Type[os.PathLike]) -> LiteralType: - return _type_models.LiteralType(blob=self._blob_type(),) + return _type_models.LiteralType( + blob=self._blob_type(), + ) def to_literal( self, ctx: FlyteContext, python_val: os.PathLike, python_type: Type[os.PathLike], expected: LiteralType @@ -591,7 +600,11 @@ def _register_default_type_transformers(): 
TypeEngine.register( SimpleTransformer( - "none", None, _type_models.LiteralType(simple=_type_models.SimpleType.NONE), lambda x: None, lambda x: None, + "none", + None, + _type_models.LiteralType(simple=_type_models.SimpleType.NONE), + lambda x: None, + lambda x: None, ) ) TypeEngine.register(ListTransformer()) diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 79cc4c6ecd..69b7b27b4f 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -36,7 +36,11 @@ from flytekit.models.core import workflow as _workflow_model GLOBAL_START_NODE = Node( - id=_common_constants.GLOBAL_INPUT_NODE_ID, metadata=None, bindings=[], upstream_nodes=[], flyte_entity=None, + id=_common_constants.GLOBAL_INPUT_NODE_ID, + metadata=None, + bindings=[], + upstream_nodes=[], + flyte_entity=None, ) @@ -347,7 +351,10 @@ def _local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], P class ImperativeWorkflow(WorkflowBase): def __init__( - self, name: str, failure_policy: Optional[WorkflowFailurePolicy] = None, interruptible: Optional[bool] = False, + self, + name: str, + failure_policy: Optional[WorkflowFailurePolicy] = None, + interruptible: Optional[bool] = False, ): metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY) workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible) @@ -625,7 +632,11 @@ def compile(self, **kwargs): ) t = self.python_interface.outputs[output_names[0]] b = binding_from_python_std( - ctx, output_names[0], self.interface.outputs[output_names[0]].type, workflow_outputs, t, + ctx, + output_names[0], + self.interface.outputs[output_names[0]].type, + workflow_outputs, + t, ) bindings.append(b) elif len(output_names) > 1: @@ -637,7 +648,13 @@ def compile(self, **kwargs): if isinstance(workflow_outputs[i], ConditionalSection): raise AssertionError("A Conditional block (if-else) should always end with an `else_()` clause") t = self.python_interface.outputs[out] - b = binding_from_python_std(ctx, out, self.interface.outputs[out].type, workflow_outputs[i], t,) + b = binding_from_python_std( + ctx, + out, + self.interface.outputs[out].type, + workflow_outputs[i], + t, + ) bindings.append(b) # Save all the things necessary to create an SdkWorkflow, except for the missing project and domain @@ -712,7 +729,10 @@ def __init__( def reference_workflow( - project: str, domain: str, name: str, version: str, + project: str, + domain: str, + name: str, + version: str, ) -> Callable[[Callable[..., Any]], ReferenceWorkflow]: """ A reference workflow is a pointer to a workflow that already exists on your Flyte installation. 
This diff --git a/flytekit/engines/common.py b/flytekit/engines/common.py index 76731f2b8a..ed1dfd9d37 100644 --- a/flytekit/engines/common.py +++ b/flytekit/engines/common.py @@ -367,7 +367,13 @@ def fetch_latest_task(self, named_task): class EngineContext(object): def __init__( - self, execution_date, tmp_dir, stats, execution_id, logging, raw_output_data_prefix=None, + self, + execution_date, + tmp_dir, + stats, + execution_id, + logging, + raw_output_data_prefix=None, ): self._stats = stats self._execution_date = execution_date diff --git a/flytekit/engines/flyte/engine.py b/flytekit/engines/flyte/engine.py index 97ee73198d..55c20e7dbb 100644 --- a/flytekit/engines/flyte/engine.py +++ b/flytekit/engines/flyte/engine.py @@ -100,7 +100,8 @@ def get_workflow_execution(self, wf_exec): return FlyteWorkflowExecution(wf_exec) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def fetch_workflow_execution(self, wf_exec_id): """ @@ -112,7 +113,8 @@ def fetch_workflow_execution(self, wf_exec_id): ).client.get_execution(wf_exec_id) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def fetch_task(self, task_id): """ @@ -125,7 +127,8 @@ def fetch_task(self, task_id): ).client.get_task(task_id) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def fetch_latest_task(self, named_task): """ @@ -136,12 +139,15 @@ def fetch_latest_task(self, named_task): task_list, _ = _FlyteClientManager( _platform_config.URL.get(), insecure=_platform_config.INSECURE.get() ).client.list_tasks_paginated( - named_task, limit=1, sort_by=_common.Sort("created_at", _common.Sort.Direction.DESCENDING), + named_task, + limit=1, + sort_by=_common.Sort("created_at", _common.Sort.Direction.DESCENDING), ) return task_list[0] if task_list else None @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def fetch_launch_plan(self, launch_plan_id): """ @@ -162,7 +168,8 @@ def fetch_launch_plan(self, launch_plan_id): ).client.get_active_launch_plan(named_entity_id) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def fetch_workflow(self, workflow_id): """ @@ -177,7 +184,8 @@ def fetch_workflow(self, workflow_id): class FlyteLaunchPlan(_common_engine.BaseLaunchPlanLauncher): @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def register(self, identifier): client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client @@ -201,11 +209,18 @@ def execute( Deprecated. Use launch instead. 
""" return self.launch( - project, domain, name, inputs, notification_overrides, label_overrides, annotation_overrides, + project, + domain, + name, + inputs, + notification_overrides, + label_overrides, + annotation_overrides, ) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def launch( self, @@ -261,7 +276,8 @@ def launch( return client.get_execution(exec_id) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def update(self, identifier, state): """ @@ -275,20 +291,28 @@ def update(self, identifier, state): class FlyteWorkflow(_common_engine.BaseWorkflowExecutor): @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def register(self, identifier): client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client try: sub_workflows = self.sdk_workflow.get_sub_workflows() - return client.create_workflow(identifier, _workflow_model.WorkflowSpec(self.sdk_workflow, sub_workflows,),) + return client.create_workflow( + identifier, + _workflow_model.WorkflowSpec( + self.sdk_workflow, + sub_workflows, + ), + ) except _user_exceptions.FlyteEntityAlreadyExistsException: pass class FlyteTask(_common_engine.BaseTaskExecutor): @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def register(self, identifier): client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client @@ -363,7 +387,9 @@ def execute(self, inputs, context=None): exc_str = _traceback.format_exc() output_file_dict[_constants.ERROR_FILE_NAME] = _error_models.ErrorDocument( _error_models.ContainerError( - "SYSTEM:Unknown", exc_str, _error_models.ContainerError.Kind.RECOVERABLE, + "SYSTEM:Unknown", + exc_str, + _error_models.ContainerError.Kind.RECOVERABLE, ) ) _logging.error(exc_str) @@ -372,11 +398,14 @@ def execute(self, inputs, context=None): for k, v in _six.iteritems(output_file_dict): _common_utils.write_proto_to_file(v.to_flyte_idl(), _os.path.join(temp_dir.name, k)) _data_proxy.Data.put_data( - temp_dir.name, context["output_prefix"], is_multipart=True, + temp_dir.name, + context["output_prefix"], + is_multipart=True, ) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def launch( self, @@ -420,7 +449,8 @@ def launch( ) assumable_iam_role = _sdk_config.ROLE.get() auth_role = _common_models.AuthRole( - assumable_iam_role=assumable_iam_role, kubernetes_service_account=kubernetes_service_account, + assumable_iam_role=assumable_iam_role, + kubernetes_service_account=kubernetes_service_account, ) try: @@ -452,7 +482,8 @@ def launch( class FlyteWorkflowExecution(_common_engine.BaseWorkflowExecution): @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def 
get_node_executions(self, filters=None): """ @@ -465,7 +496,8 @@ def get_node_executions(self, filters=None): } @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def sync(self): """ @@ -475,7 +507,8 @@ def sync(self): self.sdk_workflow_execution._closure = client.get_execution(self.sdk_workflow_execution.id).closure @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_inputs(self): """ @@ -498,7 +531,8 @@ def get_inputs(self): return _literals.LiteralMap({}) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_outputs(self): """ @@ -521,7 +555,8 @@ def get_outputs(self): return _literals.LiteralMap({}) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def terminate(self, cause): """ @@ -534,7 +569,8 @@ def terminate(self, cause): class FlyteNodeExecution(_common_engine.BaseNodeExecution): @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_task_executions(self): """ @@ -544,7 +580,8 @@ def get_task_executions(self): return list(_iterate_task_executions(client, self.sdk_node_execution.id)) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_subworkflow_executions(self): """ @@ -553,7 +590,8 @@ def get_subworkflow_executions(self): raise NotImplementedError("Cannot retrieve sub-workflow information from a node execution yet.") @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_inputs(self): """ @@ -576,7 +614,8 @@ def get_inputs(self): return _literals.LiteralMap({}) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_outputs(self): """ @@ -599,7 +638,8 @@ def get_outputs(self): return _literals.LiteralMap({}) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def sync(self): """ @@ -611,7 +651,8 @@ def sync(self): class FlyteTaskExecution(_common_engine.BaseTaskExecution): @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_inputs(self): """ @@ -634,7 +675,8 @@ def get_inputs(self): return _literals.LiteralMap({}) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client 
directly, will be removed by 1.0", + version="0.13.0", ) def get_outputs(self): """ @@ -657,7 +699,8 @@ def get_outputs(self): return _literals.LiteralMap({}) @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def sync(self): """ @@ -667,7 +710,8 @@ def sync(self): self.sdk_task_execution._closure = client.get_task_execution(self.sdk_task_execution.id).closure @_deprecated( - reason="Objects should access client directly, will be removed by 1.0", version="0.13.0", + reason="Objects should access client directly, will be removed by 1.0", + version="0.13.0", ) def get_child_executions(self, filters=None): """ @@ -678,6 +722,8 @@ def get_child_executions(self, filters=None): return { v.id.node_id: v for v in _iterate_node_executions( - client, task_execution_identifier=self.sdk_task_execution.id, filters=filters, + client, + task_execution_identifier=self.sdk_task_execution.id, + filters=filters, ) } diff --git a/flytekit/engines/unit/engine.py b/flytekit/engines/unit/engine.py index 083c8d56d9..2e3c4e187b 100644 --- a/flytekit/engines/unit/engine.py +++ b/flytekit/engines/unit/engine.py @@ -87,7 +87,8 @@ def execute(self, inputs, context=None): :rtype: dict[Text,flytekit.models.common.FlyteIdlEntity] """ with _TemporaryConfiguration( - _os.path.join(_os.path.dirname(__file__), "unit.config"), internal_overrides={"image": "unit_image"}, + _os.path.join(_os.path.dirname(__file__), "unit.config"), + internal_overrides={"image": "unit_image"}, ): with _common_utils.AutoDeletingTempDir("unit_test_dir") as working_directory: with _data_proxy.LocalWorkingDirectoryContext(working_directory): @@ -146,7 +147,8 @@ def _transform_for_user_output(self, outputs): literal_map = outputs[_sdk_constants.OUTPUT_FILE_NAME] return { name: _type_helpers.get_sdk_value_from_literal( - literal_map.literals[name], sdk_type=_type_helpers.get_sdk_type_from_literal_type(variable.type), + literal_map.literals[name], + sdk_type=_type_helpers.get_sdk_type_from_literal_type(variable.type), ).to_python_std() for name, variable in _six.iteritems(self.sdk_task.interface.outputs) } @@ -236,7 +238,11 @@ def execute_array_task(root_input_path, task, array_inputs): array_job = _array_job.ArrayJob.from_dict(task.custom) outputs = {} for job_index in _six_moves.range(0, array_job.size): - inputs_path = _os.path.join(root_input_path, _six.text_type(job_index), _sdk_constants.INPUT_FILE_NAME,) + inputs_path = _os.path.join( + root_input_path, + _six.text_type(job_index), + _sdk_constants.INPUT_FILE_NAME, + ) if inputs_path not in array_inputs: raise _system_exception.FlyteSystemAssertion( "dynamic task hasn't generated expected inputs document [{}].".format(inputs_path) diff --git a/flytekit/interfaces/data/data_proxy.py b/flytekit/interfaces/data/data_proxy.py index 18593a6099..a0babeb9ed 100644 --- a/flytekit/interfaces/data/data_proxy.py +++ b/flytekit/interfaces/data/data_proxy.py @@ -131,7 +131,10 @@ def get_data(cls, remote_path, local_path, is_multipart=False): raise _user_exception.FlyteAssertion( "Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n" "Original exception: {error_string}".format( - remote_path=remote_path, local_path=local_path, is_multipart=is_multipart, error_string=str(ex), + remote_path=remote_path, + local_path=local_path, + is_multipart=is_multipart, + error_string=str(ex), ) ) @@ -153,7 +156,10 @@ def put_data(cls, 
local_path, remote_path, is_multipart=False): raise _user_exception.FlyteAssertion( "Failed to put data from {local_path} to {remote_path} (recursive={is_multipart}).\n\n" "Original exception: {error_string}".format( - remote_path=remote_path, local_path=local_path, is_multipart=is_multipart, error_string=str(ex), + remote_path=remote_path, + local_path=local_path, + is_multipart=is_multipart, + error_string=str(ex), ) ) @@ -342,7 +348,10 @@ def get_data(self, remote_path: str, local_path: str, is_multipart=False): raise _user_exception.FlyteAssertion( "Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n" "Original exception: {error_string}".format( - remote_path=remote_path, local_path=local_path, is_multipart=is_multipart, error_string=str(ex), + remote_path=remote_path, + local_path=local_path, + is_multipart=is_multipart, + error_string=str(ex), ) ) diff --git a/flytekit/interfaces/data/gcs/gcs_proxy.py b/flytekit/interfaces/data/gcs/gcs_proxy.py index 4f79c02c65..c3ba29d635 100644 --- a/flytekit/interfaces/data/gcs/gcs_proxy.py +++ b/flytekit/interfaces/data/gcs/gcs_proxy.py @@ -126,7 +126,10 @@ def upload_directory(self, local_path, remote_path): GCSProxy._check_binary() cmd = self._maybe_with_gsutil_parallelism( - "cp", "-r", _amend_path(local_path), remote_path if remote_path.endswith("/") else remote_path + "/", + "cp", + "-r", + _amend_path(local_path), + remote_path if remote_path.endswith("/") else remote_path + "/", ) return _update_cmd_config_and_execute(cmd) diff --git a/flytekit/interfaces/stats/taggable.py b/flytekit/interfaces/stats/taggable.py index 32bc6f3ebf..42d1f93ac0 100644 --- a/flytekit/interfaces/stats/taggable.py +++ b/flytekit/interfaces/stats/taggable.py @@ -45,12 +45,18 @@ def extend_tags(self, tags): def pipeline(self): return TaggableStats( - self._client.pipeline(), self._full_prefix, prefix=self._scope_prefix, tags=dict(self._tags), + self._client.pipeline(), + self._full_prefix, + prefix=self._scope_prefix, + tags=dict(self._tags), ) def __enter__(self): return TaggableStats( - self._client.__enter__(), self._full_prefix, prefix=self._scope_prefix, tags=dict(self._tags), + self._client.__enter__(), + self._full_prefix, + prefix=self._scope_prefix, + tags=dict(self._tags), ) def get_stats(self, name, copy_tags=True): diff --git a/flytekit/models/admin/task_execution.py b/flytekit/models/admin/task_execution.py index 41d2e85c69..d0a6d4ed2d 100644 --- a/flytekit/models/admin/task_execution.py +++ b/flytekit/models/admin/task_execution.py @@ -7,7 +7,15 @@ class TaskExecutionClosure(_common.FlyteIdlEntity): def __init__( - self, phase, logs, started_at, duration, created_at, updated_at, output_uri=None, error=None, + self, + phase, + logs, + started_at, + duration, + created_at, + updated_at, + output_uri=None, + error=None, ): """ :param int phase: Enum value from flytekit.models.core.execution.TaskExecutionPhase diff --git a/flytekit/models/admin/workflow.py b/flytekit/models/admin/workflow.py index ce05f430ef..07ee1fa007 100644 --- a/flytekit/models/admin/workflow.py +++ b/flytekit/models/admin/workflow.py @@ -35,7 +35,8 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.workflow_pb2.WorkflowSpec """ return _admin_workflow.WorkflowSpec( - template=self._template.to_flyte_idl(), sub_workflows=[s.to_flyte_idl() for s in self._sub_workflows], + template=self._template.to_flyte_idl(), + sub_workflows=[s.to_flyte_idl() for s in self._sub_workflows], ) @classmethod diff --git a/flytekit/models/array_job.py 
b/flytekit/models/array_job.py index 718d43603c..4e4bf99cc7 100644 --- a/flytekit/models/array_job.py +++ b/flytekit/models/array_job.py @@ -71,7 +71,11 @@ def to_dict(self): :rtype: dict[T, Text] """ return _json_format.MessageToDict( - _array_job.ArrayJob(parallelism=self.parallelism, size=self.size, min_successes=self.min_successes,) + _array_job.ArrayJob( + parallelism=self.parallelism, + size=self.size, + min_successes=self.min_successes, + ) ) @classmethod @@ -82,4 +86,8 @@ def from_dict(cls, idl_dict): """ pb2_object = _json_format.Parse(_json.dumps(idl_dict), _array_job.ArrayJob()) - return cls(parallelism=pb2_object.parallelism, size=pb2_object.size, min_successes=pb2_object.min_successes,) + return cls( + parallelism=pb2_object.parallelism, + size=pb2_object.size, + min_successes=pb2_object.min_successes, + ) diff --git a/flytekit/models/core/compiler.py b/flytekit/models/core/compiler.py index 6b6e03003c..3246ee22b3 100644 --- a/flytekit/models/core/compiler.py +++ b/flytekit/models/core/compiler.py @@ -105,7 +105,8 @@ def to_flyte_idl(self): :rtype: flyteidl.core.compiler_pb2.CompiledWorkflow """ return _compiler_pb2.CompiledWorkflow( - template=self.template.to_flyte_idl(), connections=self.connections.to_flyte_idl(), + template=self.template.to_flyte_idl(), + connections=self.connections.to_flyte_idl(), ) @classmethod diff --git a/flytekit/models/core/condition.py b/flytekit/models/core/condition.py index 54b99e6b21..845b3b4f79 100644 --- a/flytekit/models/core/condition.py +++ b/flytekit/models/core/condition.py @@ -165,7 +165,8 @@ def to_flyte_idl(self): :rtype: flyteidl.core.condition_pb2.Operand """ return _condition.Operand( - primitive=self.primitive.to_flyte_idl() if self.primitive else None, var=self.var if self.var else None, + primitive=self.primitive.to_flyte_idl() if self.primitive else None, + var=self.var if self.var else None, ) @classmethod diff --git a/flytekit/models/core/execution.py b/flytekit/models/core/execution.py index 2c20edc48b..5323e0489c 100644 --- a/flytekit/models/core/execution.py +++ b/flytekit/models/core/execution.py @@ -147,7 +147,11 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.execution_pb2.ExecutionError """ - return _execution_pb2.ExecutionError(code=self.code, message=self.message, error_uri=self.error_uri,) + return _execution_pb2.ExecutionError( + code=self.code, + message=self.message, + error_uri=self.error_uri, + ) @classmethod def from_flyte_idl(cls, p): @@ -155,7 +159,11 @@ def from_flyte_idl(cls, p): :param flyteidl.core.execution_pb2.ExecutionError p: :rtype: ExecutionError """ - return cls(code=p.code, message=p.message, error_uri=p.error_uri,) + return cls( + code=p.code, + message=p.message, + error_uri=p.error_uri, + ) class TaskLog(_common.FlyteIdlEntity): @@ -219,4 +227,9 @@ def from_flyte_idl(cls, p): :param flyteidl.core.execution_pb2.TaskLog p: :rtype: TaskLog """ - return cls(uri=p.uri, name=p.name, message_format=p.message_format, ttl=p.ttl.ToTimedelta(),) + return cls( + uri=p.uri, + name=p.name, + message_format=p.message_format, + ttl=p.ttl.ToTimedelta(), + ) diff --git a/flytekit/models/core/identifier.py b/flytekit/models/core/identifier.py index 0fa6a87594..b65f7d269d 100644 --- a/flytekit/models/core/identifier.py +++ b/flytekit/models/core/identifier.py @@ -82,7 +82,13 @@ def from_flyte_idl(cls, p): :param flyteidl.core.identifier_pb2.Identifier p: :rtype: Identifier """ - return cls(resource_type=p.resource_type, project=p.project, domain=p.domain, name=p.name, version=p.version,) + return cls( + 
resource_type=p.resource_type, + project=p.project, + domain=p.domain, + name=p.name, + version=p.version, + ) class WorkflowExecutionIdentifier(_common_models.FlyteIdlEntity): @@ -121,7 +127,11 @@ def to_flyte_idl(self): """ :rtype: flyteidl.core.identifier_pb2.WorkflowExecutionIdentifier """ - return _identifier_pb2.WorkflowExecutionIdentifier(project=self.project, domain=self.domain, name=self.name,) + return _identifier_pb2.WorkflowExecutionIdentifier( + project=self.project, + domain=self.domain, + name=self.name, + ) @classmethod def from_flyte_idl(cls, p): @@ -129,7 +139,11 @@ def from_flyte_idl(cls, p): :param flyteidl.core.identifier_pb2.WorkflowExecutionIdentifier p: :rtype: WorkflowExecutionIdentifier """ - return cls(project=p.project, domain=p.domain, name=p.name,) + return cls( + project=p.project, + domain=p.domain, + name=p.name, + ) class NodeExecutionIdentifier(_common_models.FlyteIdlEntity): @@ -160,7 +174,8 @@ def to_flyte_idl(self): :rtype: flyteidl.core.identifier_pb2.NodeExecutionIdentifier """ return _identifier_pb2.NodeExecutionIdentifier( - node_id=self.node_id, execution_id=self.execution_id.to_flyte_idl(), + node_id=self.node_id, + execution_id=self.execution_id.to_flyte_idl(), ) @classmethod @@ -169,7 +184,10 @@ def from_flyte_idl(cls, p): :param flyteidl.core.identifier_pb2.NodeExecutionIdentifier p: :rtype: NodeExecutionIdentifier """ - return cls(node_id=p.node_id, execution_id=WorkflowExecutionIdentifier.from_flyte_idl(p.execution_id),) + return cls( + node_id=p.node_id, + execution_id=WorkflowExecutionIdentifier.from_flyte_idl(p.execution_id), + ) class TaskExecutionIdentifier(_common_models.FlyteIdlEntity): diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index f15ea7b415..9bec4768a8 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -197,7 +197,9 @@ def to_flyte_idl(self): :rtype: flyteidl.core.workflow_pb2.NodeMetadata """ node_metadata = _core_workflow.NodeMetadata( - name=self.name, retries=self.retries.to_flyte_idl(), interruptible=self.interruptible, + name=self.name, + retries=self.retries.to_flyte_idl(), + interruptible=self.interruptible, ) if self.timeout: node_metadata.timeout.FromTimedelta(self.timeout) @@ -206,7 +208,9 @@ def to_flyte_idl(self): @classmethod def from_flyte_idl(cls, pb2_object): return cls( - pb2_object.name, pb2_object.timeout.ToTimedelta(), _RetryStrategy.from_flyte_idl(pb2_object.retries), + pb2_object.name, + pb2_object.timeout.ToTimedelta(), + _RetryStrategy.from_flyte_idl(pb2_object.retries), ) @@ -541,7 +545,14 @@ def from_flyte_idl(cls, pb2_object): class WorkflowTemplate(_common.FlyteIdlEntity): def __init__( - self, id, metadata, metadata_defaults, interface, nodes, outputs, failure_node=None, + self, + id, + metadata, + metadata_defaults, + interface, + nodes, + outputs, + failure_node=None, ): """ A workflow template encapsulates all the task, branch, and subworkflow nodes to run a statically analyzable, diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 977495bae6..68ff7d44e8 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -62,7 +62,11 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.execution_pb2.ExecutionMetadata pb2_object: :return: ExecutionMetadata """ - return cls(mode=pb2_object.mode, principal=pb2_object.principal, nesting=pb2_object.nesting,) + return cls( + mode=pb2_object.mode, + principal=pb2_object.principal, + nesting=pb2_object.nesting, + ) class 
ExecutionSpec(_common_models.FlyteIdlEntity): @@ -203,7 +207,8 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.execution_pb2.LiteralMapBlob """ return _execution_pb2.LiteralMapBlob( - values=self.values.to_flyte_idl() if self.values is not None else None, uri=self.uri, + values=self.values.to_flyte_idl() if self.values is not None else None, + uri=self.uri, ) @classmethod @@ -256,7 +261,9 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.execution_pb2.Execution """ return _execution_pb2.Execution( - id=self.id.to_flyte_idl(), closure=self.closure.to_flyte_idl(), spec=self.spec.to_flyte_idl(), + id=self.id.to_flyte_idl(), + closure=self.closure.to_flyte_idl(), + spec=self.spec.to_flyte_idl(), ) @classmethod diff --git a/flytekit/models/interface.py b/flytekit/models/interface.py index d824b6ba60..87127bc13c 100644 --- a/flytekit/models/interface.py +++ b/flytekit/models/interface.py @@ -45,7 +45,10 @@ def from_flyte_idl(cls, variable_proto): :param flyteidl.core.interface_pb2.Variable variable_proto: :rtype: Variable """ - return cls(type=_types.LiteralType.from_flyte_idl(variable_proto.type), description=variable_proto.description,) + return cls( + type=_types.LiteralType.from_flyte_idl(variable_proto.type), + description=variable_proto.description, + ) class VariableMap(_common.FlyteIdlEntity): diff --git a/flytekit/models/launch_plan.py b/flytekit/models/launch_plan.py index ba92b0f7d7..d5efac5857 100644 --- a/flytekit/models/launch_plan.py +++ b/flytekit/models/launch_plan.py @@ -370,7 +370,9 @@ def to_flyte_idl(self): else _identifier.Identifier(_identifier.ResourceType.LAUNCH_PLAN, None, None, None, None) ) return _launch_plan.LaunchPlan( - id=identifier.to_flyte_idl(), spec=self.spec.to_flyte_idl(), closure=self.closure.to_flyte_idl(), + id=identifier.to_flyte_idl(), + spec=self.spec.to_flyte_idl(), + closure=self.closure.to_flyte_idl(), ) @classmethod diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index 3b8405e106..c1f871bf21 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -45,7 +45,13 @@ def from_flyte_idl(cls, pb2_object): class Primitive(_common.FlyteIdlEntity): def __init__( - self, integer=None, float_value=None, string_value=None, boolean=None, datetime=None, duration=None, + self, + integer=None, + float_value=None, + string_value=None, + boolean=None, + datetime=None, + duration=None, ): """ This object proxies the primitives supported by the Flyte IDL system. Only one value can be set. @@ -134,7 +140,10 @@ def to_flyte_idl(self): :rtype: flyteidl.core.literals_pb2.Primitive """ primitive = _literals_pb2.Primitive( - integer=self.integer, float_value=self.float_value, string_value=self.string_value, boolean=self.boolean, + integer=self.integer, + float_value=self.float_value, + string_value=self.string_value, + boolean=self.boolean, ) if self.datetime is not None: # Convert to UTC and remove timezone so protobuf behaves. 
@@ -434,7 +443,8 @@ def to_literal_model(self): """ if self.promise: raise _user_exceptions.FlyteValueException( - self.promise, "Cannot convert BindingData to a Literal because " "it has a promise.", + self.promise, + "Cannot convert BindingData to a Literal because " "it has a promise.", ) elif self.scalar: return Literal(scalar=self.scalar) diff --git a/flytekit/models/matchable_resource.py b/flytekit/models/matchable_resource.py index 8b3e9d144a..64247f5bf5 100644 --- a/flytekit/models/matchable_resource.py +++ b/flytekit/models/matchable_resource.py @@ -73,7 +73,9 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.matchable_resource_pb2.ClusterResourceAttributes """ - return _matchable_resource.ClusterResourceAttributes(attributes=self.attributes,) + return _matchable_resource.ClusterResourceAttributes( + attributes=self.attributes, + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -81,7 +83,9 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.matchable_resource_pb2.ClusterResourceAttributes pb2_object: :rtype: ClusterResourceAttributes """ - return cls(attributes=pb2_object.attributes,) + return cls( + attributes=pb2_object.attributes, + ) class ExecutionQueueAttributes(_common.FlyteIdlEntity): @@ -104,7 +108,9 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.matchable_resource_pb2.ExecutionQueueAttributes """ - return _matchable_resource.ExecutionQueueAttributes(tags=self.tags,) + return _matchable_resource.ExecutionQueueAttributes( + tags=self.tags, + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -112,7 +118,9 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.matchable_resource_pb2.ExecutionQueueAttributes pb2_object: :rtype: ExecutionQueueAttributes """ - return cls(tags=pb2_object.tags,) + return cls( + tags=pb2_object.tags, + ) class ExecutionClusterLabel(_common.FlyteIdlEntity): @@ -135,7 +143,9 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.matchable_resource_pb2.ExecutionClusterLabel """ - return _matchable_resource.ExecutionClusterLabel(value=self.value,) + return _matchable_resource.ExecutionClusterLabel( + value=self.value, + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -143,7 +153,9 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.matchable_resource_pb2.ExecutionClusterLabel pb2_object: :rtype: ExecutionClusterLabel """ - return cls(value=pb2_object.value,) + return cls( + value=pb2_object.value, + ) class PluginOverride(_common.FlyteIdlEntity): @@ -201,7 +213,9 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.matchable_resource_pb2.PluginOverride """ return _matchable_resource.PluginOverride( - task_type=self.task_type, plugin_id=self.plugin_id, missing_plugin_behavior=self.missing_plugin_behavior, + task_type=self.task_type, + plugin_id=self.plugin_id, + missing_plugin_behavior=self.missing_plugin_behavior, ) @classmethod diff --git a/flytekit/models/named_entity.py b/flytekit/models/named_entity.py index 80d70aa35c..63dd598d98 100644 --- a/flytekit/models/named_entity.py +++ b/flytekit/models/named_entity.py @@ -57,7 +57,11 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.common_pb2.NamedEntityIdentifier """ - return _common.NamedEntityIdentifier(project=self.project, domain=self.domain, name=self.name,) + return _common.NamedEntityIdentifier( + project=self.project, + domain=self.domain, + name=self.name, + ) @classmethod def from_flyte_idl(cls, p): @@ -65,7 +69,11 @@ def from_flyte_idl(cls, p): :param flyteidl.core.common_pb2.NamedEntityIdentifier p: :rtype: Identifier """ - return 
cls(project=p.project, domain=p.domain, name=p.name,) + return cls( + project=p.project, + domain=p.domain, + name=p.name, + ) class NamedEntityMetadata(_common_models.FlyteIdlEntity): @@ -97,7 +105,10 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.common_pb2.NamedEntityMetadata """ - return _common.NamedEntityMetadata(description=self.description, state=self.state,) + return _common.NamedEntityMetadata( + description=self.description, + state=self.state, + ) @classmethod def from_flyte_idl(cls, p): @@ -105,4 +116,7 @@ def from_flyte_idl(cls, p): :param flyteidl.core.common_pb2.NamedEntityMetadata p: :rtype: Identifier """ - return cls(description=p.description, state=p.state,) + return cls( + description=p.description, + state=p.state, + ) diff --git a/flytekit/models/node_execution.py b/flytekit/models/node_execution.py index b0c103891b..4550fb6443 100644 --- a/flytekit/models/node_execution.py +++ b/flytekit/models/node_execution.py @@ -121,7 +121,9 @@ def to_flyte_idl(self): :rtype: flyteidl.admin.node_execution_pb2.NodeExecution """ return _node_execution_pb2.NodeExecution( - id=self.id.to_flyte_idl(), input_uri=self.input_uri, closure=self.closure.to_flyte_idl(), + id=self.id.to_flyte_idl(), + input_uri=self.input_uri, + closure=self.closure.to_flyte_idl(), ) @classmethod diff --git a/flytekit/models/presto.py b/flytekit/models/presto.py index 04c4b22b41..2f5f998153 100644 --- a/flytekit/models/presto.py +++ b/flytekit/models/presto.py @@ -55,7 +55,10 @@ def to_flyte_idl(self): :rtype: _presto.PrestoQuery """ return _presto.PrestoQuery( - routing_group=self._routing_group, catalog=self._catalog, schema=self._schema, statement=self._statement, + routing_group=self._routing_group, + catalog=self._catalog, + schema=self._schema, + statement=self._statement, ) @classmethod diff --git a/flytekit/models/qubole.py b/flytekit/models/qubole.py index 3464158ad3..2247d6e5fa 100644 --- a/flytekit/models/qubole.py +++ b/flytekit/models/qubole.py @@ -51,7 +51,11 @@ def from_flyte_idl(cls, pb2_object): :param _qubole.HiveQuery pb2_object: :return: HiveQuery """ - return cls(query=pb2_object.query, timeout_sec=pb2_object.timeout_sec, retry_count=pb2_object.retryCount,) + return cls( + query=pb2_object.query, + timeout_sec=pb2_object.timeout_sec, + retry_count=pb2_object.retryCount, + ) class HiveQueryCollection(_common.FlyteIdlEntity): diff --git a/flytekit/models/sagemaker/hpo_job.py b/flytekit/models/sagemaker/hpo_job.py index 966318adab..ea484f26cf 100644 --- a/flytekit/models/sagemaker/hpo_job.py +++ b/flytekit/models/sagemaker/hpo_job.py @@ -18,7 +18,9 @@ class HyperparameterTuningObjective(_common.FlyteIdlEntity): """ def __init__( - self, objective_type: int, metric_name: str, + self, + objective_type: int, + metric_name: str, ): self._objective_type = objective_type self._metric_name = metric_name @@ -44,13 +46,17 @@ def metric_name(self) -> str: def to_flyte_idl(self) -> _pb2_hpo_job.HyperparameterTuningObjective: return _pb2_hpo_job.HyperparameterTuningObjective( - objective_type=self.objective_type, metric_name=self._metric_name, + objective_type=self.objective_type, + metric_name=self._metric_name, ) @classmethod def from_flyte_idl(cls, pb2_object: _pb2_hpo_job.HyperparameterTuningObjective): - return cls(objective_type=pb2_object.objective_type, metric_name=pb2_object.metric_name,) + return cls( + objective_type=pb2_object.objective_type, + metric_name=pb2_object.metric_name, + ) class HyperparameterTuningStrategy: diff --git a/flytekit/models/sagemaker/parameter_ranges.py 
b/flytekit/models/sagemaker/parameter_ranges.py index e9016f2a1d..4328749d72 100644 --- a/flytekit/models/sagemaker/parameter_ranges.py +++ b/flytekit/models/sagemaker/parameter_ranges.py @@ -15,7 +15,10 @@ class HyperparameterScalingType(object): class ContinuousParameterRange(_common.FlyteIdlEntity): def __init__( - self, max_value: float, min_value: float, scaling_type: int, + self, + max_value: float, + min_value: float, + scaling_type: int, ): """ @@ -57,7 +60,9 @@ def to_flyte_idl(self) -> _idl_parameter_ranges.ContinuousParameterRange: """ return _idl_parameter_ranges.ContinuousParameterRange( - max_value=self._max_value, min_value=self._min_value, scaling_type=self.scaling_type, + max_value=self._max_value, + min_value=self._min_value, + scaling_type=self.scaling_type, ) @classmethod @@ -68,13 +73,18 @@ def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.ContinuousParameterRan :rtype: ContinuousParameterRange """ return cls( - max_value=pb2_object.max_value, min_value=pb2_object.min_value, scaling_type=pb2_object.scaling_type, + max_value=pb2_object.max_value, + min_value=pb2_object.min_value, + scaling_type=pb2_object.scaling_type, ) class IntegerParameterRange(_common.FlyteIdlEntity): def __init__( - self, max_value: int, min_value: int, scaling_type: int, + self, + max_value: int, + min_value: int, + scaling_type: int, ): """ :param int max_value: @@ -113,7 +123,9 @@ def to_flyte_idl(self) -> _idl_parameter_ranges.IntegerParameterRange: :rtype: _idl_parameter_ranges.IntegerParameterRange """ return _idl_parameter_ranges.IntegerParameterRange( - max_value=self._max_value, min_value=self._min_value, scaling_type=self.scaling_type, + max_value=self._max_value, + min_value=self._min_value, + scaling_type=self.scaling_type, ) @classmethod @@ -124,13 +136,16 @@ def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.IntegerParameterRange) :rtype: IntegerParameterRange """ return cls( - max_value=pb2_object.max_value, min_value=pb2_object.min_value, scaling_type=pb2_object.scaling_type, + max_value=pb2_object.max_value, + min_value=pb2_object.min_value, + scaling_type=pb2_object.scaling_type, ) class CategoricalParameterRange(_common.FlyteIdlEntity): def __init__( - self, values: List[str], + self, + values: List[str], ): """ @@ -163,7 +178,8 @@ def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.CategoricalParameterRa class ParameterRanges(_common.FlyteIdlEntity): def __init__( - self, parameter_range_map: Dict[str, _common.FlyteIdlEntity], + self, + parameter_range_map: Dict[str, _common.FlyteIdlEntity], ): self._parameter_range_map = parameter_range_map @@ -188,7 +204,9 @@ def to_flyte_idl(self) -> _idl_parameter_ranges.ParameterRanges: ), ) - return _idl_parameter_ranges.ParameterRanges(parameter_range_map=converted,) + return _idl_parameter_ranges.ParameterRanges( + parameter_range_map=converted, + ) @classmethod def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.ParameterRanges): @@ -206,7 +224,9 @@ def from_flyte_idl(cls, pb2_object: _idl_parameter_ranges.ParameterRanges): else: converted[k] = CategoricalParameterRange.from_flyte_idl(v.categorical_parameter_range) - return cls(parameter_range_map=converted,) + return cls( + parameter_range_map=converted, + ) class ParameterRangeOneOf(_common.FlyteIdlEntity): diff --git a/flytekit/models/sagemaker/training_job.py b/flytekit/models/sagemaker/training_job.py index faf4dd7e29..674effcbc4 100644 --- a/flytekit/models/sagemaker/training_job.py +++ b/flytekit/models/sagemaker/training_job.py @@ -96,7 +96,9 @@ def 
from_flyte_idl(cls, pb2_object: _training_job_pb2.TrainingJobResourceConfig) class MetricDefinition(_common.FlyteIdlEntity): def __init__( - self, name: str, regex: str, + self, + name: str, + regex: str, ): self._name = name self._regex = regex @@ -123,7 +125,10 @@ def to_flyte_idl(self) -> _training_job_pb2.MetricDefinition: :rtype: _training_job_pb2.MetricDefinition """ - return _training_job_pb2.MetricDefinition(name=self.name, regex=self.regex,) + return _training_job_pb2.MetricDefinition( + name=self.name, + regex=self.regex, + ) @classmethod def from_flyte_idl(cls, pb2_object: _training_job_pb2.MetricDefinition): @@ -132,7 +137,10 @@ def from_flyte_idl(cls, pb2_object: _training_job_pb2.MetricDefinition): :param pb2_object: _training_job_pb2.MetricDefinition :rtype: MetricDefinition """ - return cls(name=pb2_object.name, regex=pb2_object.regex,) + return cls( + name=pb2_object.name, + regex=pb2_object.regex, + ) # TODO Convert to Enum @@ -270,7 +278,9 @@ def from_flyte_idl(cls, pb2_object: _training_job_pb2.AlgorithmSpecification): class TrainingJob(_common.FlyteIdlEntity): def __init__( - self, algorithm_specification: AlgorithmSpecification, training_job_resource_config: TrainingJobResourceConfig, + self, + algorithm_specification: AlgorithmSpecification, + training_job_resource_config: TrainingJobResourceConfig, ): self._algorithm_specification = algorithm_specification self._training_job_resource_config = training_job_resource_config diff --git a/flytekit/models/security.py b/flytekit/models/security.py index dbe5196a7d..e4ea655e22 100644 --- a/flytekit/models/security.py +++ b/flytekit/models/security.py @@ -68,11 +68,17 @@ class OAuth2Client(_common.FlyteIdlEntity): client_secret: str def to_flyte_idl(self) -> _sec.OAuth2Client: - return _sec.OAuth2Client(client_id=self.client_id, client_secret=self.client_secret,) + return _sec.OAuth2Client( + client_id=self.client_id, + client_secret=self.client_secret, + ) @classmethod def from_flyte_idl(cls, pb2_object: _sec.OAuth2Client) -> "OAuth2Client": - return cls(client_id=pb2_object.client_id, client_secret=pb2_object.client_secret,) + return cls( + client_id=pb2_object.client_id, + client_secret=pb2_object.client_secret, + ) @dataclass diff --git a/flytekit/models/task.py b/flytekit/models/task.py index 3f517d0aee..8e34b90cf7 100644 --- a/flytekit/models/task.py +++ b/flytekit/models/task.py @@ -100,7 +100,8 @@ def to_flyte_idl(self): :rtype: flyteidl.core.tasks_pb2.Resources """ return _core_task.Resources( - requests=[r.to_flyte_idl() for r in self.requests], limits=[r.to_flyte_idl() for r in self.limits], + requests=[r.to_flyte_idl() for r in self.requests], + limits=[r.to_flyte_idl() for r in self.limits], ) @classmethod @@ -172,7 +173,14 @@ def from_flyte_idl(cls, pb2_object): class TaskMetadata(_common.FlyteIdlEntity): def __init__( - self, discoverable, runtime, timeout, retries, interruptible, discovery_version, deprecated_error_message, + self, + discoverable, + runtime, + timeout, + retries, + interruptible, + discovery_version, + deprecated_error_message, ): """ Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts, @@ -496,7 +504,10 @@ def to_flyte_idl(self): """ :rtype: flyteidl.admin.task_pb2.Task """ - return _admin_task.Task(closure=self.closure.to_flyte_idl(), id=self.id.to_flyte_idl(),) + return _admin_task.Task( + closure=self.closure.to_flyte_idl(), + id=self.id.to_flyte_idl(), + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -570,7 +581,13 @@ def 
from_flyte_idl(cls, pb2_object): class SparkJob(_common.FlyteIdlEntity): def __init__( - self, spark_type, application_file, main_class, spark_conf, hadoop_conf, executor_path, + self, + spark_type, + application_file, + main_class, + spark_conf, + hadoop_conf, + executor_path, ): """ This defines a SparkJob target. It will execute the appropriate SparkJob. @@ -730,7 +747,10 @@ def to_flyte_idl(self) -> _core_task.IOStrategy: def from_flyte_idl(cls, pb2_object: _core_task.IOStrategy): if pb2_object is None: return None - return cls(download_mode=pb2_object.download_mode, upload_mode=pb2_object.upload_mode,) + return cls( + download_mode=pb2_object.download_mode, + upload_mode=pb2_object.upload_mode, + ) class DataLoadingConfig(_common.FlyteIdlEntity): @@ -931,7 +951,10 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.admin.task_pb2.Task pb2_object: :rtype: Container """ - return cls(pod_spec=pb2_object.pod_spec, primary_container_name=pb2_object.primary_container_name,) + return cls( + pod_spec=pb2_object.pod_spec, + primary_container_name=pb2_object.primary_container_name, + ) class PyTorchJob(_common.FlyteIdlEntity): @@ -943,11 +966,15 @@ def workers_count(self): return self._workers_count def to_flyte_idl(self): - return _pytorch_task.DistributedPyTorchTrainingTask(workers=self.workers_count,) + return _pytorch_task.DistributedPyTorchTrainingTask( + workers=self.workers_count, + ) @classmethod def from_flyte_idl(cls, pb2_object): - return cls(workers_count=pb2_object.workers,) + return cls( + workers_count=pb2_object.workers, + ) class TensorFlowJob(_common.FlyteIdlEntity): diff --git a/flytekit/models/types.py b/flytekit/models/types.py index 7d68645c7a..69cfb75dd2 100644 --- a/flytekit/models/types.py +++ b/flytekit/models/types.py @@ -100,7 +100,13 @@ def from_flyte_idl(cls, proto): class LiteralType(_common.FlyteIdlEntity): def __init__( - self, simple=None, schema=None, collection_type=None, map_value_type=None, blob=None, metadata=None, + self, + simple=None, + schema=None, + collection_type=None, + map_value_type=None, + blob=None, + metadata=None, ): """ Only one of the kwargs may be set. 
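[Editor's note] The LiteralType constructor reformatted above acts like a protobuf oneof: per its docstring, only one of the type kwargs (simple, schema, collection_type, map_value_type, blob) should be set, with metadata as an optional extra. A minimal usage sketch, assuming the SimpleType enum that flytekit.models.types exposes alongside LiteralType:

from flytekit.models.types import LiteralType, SimpleType

# Exactly one kind per LiteralType instance: a plain integer type.
int_type = LiteralType(simple=SimpleType.INTEGER)

# Composite kinds nest another LiteralType, again setting one kwarg per level:
# this describes a collection (list) whose elements are strings.
str_list_type = LiteralType(collection_type=LiteralType(simple=SimpleType.STRING))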
@@ -258,7 +264,10 @@ def __init__(self, failed_node_id: str, message: str): self._failed_node_id = failed_node_id def to_flyte_idl(self) -> _types_pb2.Error: - return _types_pb2.Error(message=self._message, failed_node_id=self._failed_node_id,) + return _types_pb2.Error( + message=self._message, + failed_node_id=self._failed_node_id, + ) @classmethod def from_flyte_idl(cls, pb2_object: _types_pb2.Error) -> "Error": diff --git a/flytekit/models/workflow_closure.py b/flytekit/models/workflow_closure.py index 412a52e958..fbf0b08688 100644 --- a/flytekit/models/workflow_closure.py +++ b/flytekit/models/workflow_closure.py @@ -33,7 +33,8 @@ def to_flyte_idl(self): :rtype: flyteidl.core.workflow_closure_pb2.WorkflowClosure """ return _workflow_closure_pb2.WorkflowClosure( - workflow=self.workflow.to_flyte_idl(), tasks=[t.to_flyte_idl() for t in self.tasks], + workflow=self.workflow.to_flyte_idl(), + tasks=[t.to_flyte_idl() for t in self.tasks], ) @classmethod diff --git a/flytekit/plugins/__init__.py b/flytekit/plugins/__init__.py index 5c928bbb60..e20bf5eb36 100644 --- a/flytekit/plugins/__init__.py +++ b/flytekit/plugins/__init__.py @@ -30,7 +30,9 @@ _lazy_loader.LazyLoadPlugin("sidecar", ["k8s-proto>=0.0.3,<1.0.0"], [k8s, flyteidl]) _lazy_loader.LazyLoadPlugin( - "schema", ["numpy>=1.14.0,<2.0.0", "pandas>=0.22.0,<2.0.0", "pyarrow>=0.11.0,<1.0.0"], [numpy, pandas], + "schema", + ["numpy>=1.14.0,<2.0.0", "pandas>=0.22.0,<2.0.0", "pyarrow>=0.11.0,<1.0.0"], + [numpy, pandas], ) _lazy_loader.LazyLoadPlugin("hive_sensor", ["hmsclient>=0.0.1,<1.0.0"], [hmsclient]) diff --git a/flytekit/sdk/tasks.py b/flytekit/sdk/tasks.py index 7ecd5682ab..1e6dcd5e4c 100644 --- a/flytekit/sdk/tasks.py +++ b/flytekit/sdk/tasks.py @@ -44,8 +44,11 @@ def my_task(wf_params, in1, in2, out1, out2): def apply_inputs_wrapper(task): if not isinstance(task, _task.SdkTask): - additional_msg = "Inputs can only be applied to a task. Did you forget the task decorator on method '{}.{}'?".format( - task.__module__, task.__name__ if hasattr(task, "__name__") else "", + additional_msg = ( + "Inputs can only be applied to a task. Did you forget the task decorator on method '{}.{}'?".format( + task.__module__, + task.__name__ if hasattr(task, "__name__") else "", + ) ) raise _user_exceptions.FlyteTypeException( expected_type=_sdk_runnable_tasks.SdkRunnableTask, @@ -94,8 +97,11 @@ def apply_outputs_wrapper(task): if not isinstance(task, _sdk_runnable_tasks.SdkRunnableTask) and not isinstance( task, _nb_tasks.SdkNotebookTask ): - additional_msg = "Outputs can only be applied to a task. Did you forget the task decorator on method '{}.{}'?".format( - task.__module__, task.__name__ if hasattr(task, "__name__") else "", + additional_msg = ( + "Outputs can only be applied to a task. Did you forget the task decorator on method '{}.{}'?".format( + task.__module__, + task.__name__ if hasattr(task, "__name__") else "", + ) ) raise _user_exceptions.FlyteTypeException( expected_type=_sdk_runnable_tasks.SdkRunnableTask, diff --git a/flytekit/sdk/workflow.py b/flytekit/sdk/workflow.py index d4cf8aedea..34561dc592 100644 --- a/flytekit/sdk/workflow.py +++ b/flytekit/sdk/workflow.py @@ -40,7 +40,10 @@ def __init__(self, value, sdk_type=None, help=None): this value be provided as the SDK might not always be able to infer the correct type. 
""" super(Output, self).__init__( - "", value, sdk_type=_type_helpers.python_std_to_sdk_type(sdk_type) if sdk_type else None, help=help, + "", + value, + sdk_type=_type_helpers.python_std_to_sdk_type(sdk_type) if sdk_type else None, + help=help, ) diff --git a/flytekit/tools/fast_registration.py b/flytekit/tools/fast_registration.py index 417cfee5c4..00ef36855e 100644 --- a/flytekit/tools/fast_registration.py +++ b/flytekit/tools/fast_registration.py @@ -105,6 +105,7 @@ def download_distribution(additional_distribution: str, destination: str): # This will overwrite the existing user flyte workflow code in the current working code dir. result = _subprocess.run( - ["tar", "-xvf", _os.path.join(destination, tarfile_name), "-C", destination], stdout=_subprocess.PIPE, + ["tar", "-xvf", _os.path.join(destination, tarfile_name), "-C", destination], + stdout=_subprocess.PIPE, ) result.check_returncode() diff --git a/flytekit/tools/module_loader.py b/flytekit/tools/module_loader.py index 7983ecae8b..7149e532af 100644 --- a/flytekit/tools/module_loader.py +++ b/flytekit/tools/module_loader.py @@ -139,7 +139,11 @@ def iterate(): def iterate_registerable_entities_in_order( - pkgs, local_source_root=None, ignore_entities=None, include_entities=None, detect_unreferenced_entities=True, + pkgs, + local_source_root=None, + ignore_entities=None, + include_entities=None, + detect_unreferenced_entities=True, ): """ This function will iterate all discovered entities in the given package list. It will then attempt to diff --git a/flytekit/type_engines/default/flyte.py b/flytekit/type_engines/default/flyte.py index b930a63d97..0ec3a4c982 100644 --- a/flytekit/type_engines/default/flyte.py +++ b/flytekit/type_engines/default/flyte.py @@ -21,7 +21,8 @@ def _load_type_from_tag(tag: str) -> Type: if "." not in tag: raise _user_exceptions.FlyteValueException( - tag, "Protobuf tag must include at least one '.' to delineate package and object name.", + tag, + "Protobuf tag must include at least one '.' 
to delineate package and object name.", ) module, name = tag.rsplit(".", 1) diff --git a/plugins/awssagemaker/flytekitplugins/awssagemaker/training.py b/plugins/awssagemaker/flytekitplugins/awssagemaker/training.py index 392ebebb92..8ff51b1d15 100644 --- a/plugins/awssagemaker/flytekitplugins/awssagemaker/training.py +++ b/plugins/awssagemaker/flytekitplugins/awssagemaker/training.py @@ -49,7 +49,10 @@ class SagemakerBuiltinAlgorithmsTask(PythonTask[SagemakerTrainingJobConfig]): OUTPUT_TYPE = TypeVar("tar.gz") def __init__( - self, name: str, task_config: SagemakerTrainingJobConfig, **kwargs, + self, + name: str, + task_config: SagemakerTrainingJobConfig, + **kwargs, ): """ Args: @@ -75,7 +78,11 @@ def __init__( outputs=kwtypes(model=FlyteFile[self.OUTPUT_TYPE]), ) super().__init__( - self._SAGEMAKER_TRAINING_JOB_TASK, name, interface=interface, task_config=task_config, **kwargs, + self._SAGEMAKER_TRAINING_JOB_TASK, + name, + interface=interface, + task_config=task_config, + **kwargs, ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: @@ -110,7 +117,10 @@ class SagemakerCustomTrainingTask(PythonFunctionTask[SagemakerTrainingJobConfig] _SAGEMAKER_CUSTOM_TRAINING_JOB_TASK = "sagemaker_custom_training_job_task" def __init__( - self, task_config: SagemakerTrainingJobConfig, task_function: Callable, **kwargs, + self, + task_config: SagemakerTrainingJobConfig, + task_function: Callable, + **kwargs, ): super().__init__( task_config=task_config, @@ -145,7 +155,7 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: exec_state = FlyteContext.current_context().execution_state if exec_state and exec_state.mode == ExecutionState.Mode.TASK_EXECUTION: """ - This mode indicates we are actually in a remote execute environment (within sagemaker in this case) + This mode indicates we are actually in a remote execution environment (within SageMaker in this case) """ dist_ctx = DistributedTrainingContext.from_env() else: diff --git a/plugins/hive/flytekitplugins/hive/task.py b/plugins/hive/flytekitplugins/hive/task.py index 31c08fd46f..9552268265 100644 --- a/plugins/hive/flytekitplugins/hive/task.py +++ b/plugins/hive/flytekitplugins/hive/task.py @@ -87,7 +87,11 @@ def tags(self) -> List[str]: def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: # timeout_sec and retry_count will become deprecated, please use timeout and retry settings on the Task query = HiveQuery(query=self.query_template, timeout_sec=0, retry_count=0) - job = QuboleHiveJob(query=query, cluster_label=self.cluster_label, tags=self.tags,) + job = QuboleHiveJob( + query=query, + cluster_label=self.cluster_label, + tags=self.tags, + ) return MessageToDict(job.to_flyte_idl()) @@ -116,10 +120,10 @@ def __init__( **kwargs, ): """ - Args: - select_query: Singular query that returns a Tabular dataset - stage_query: optional query that should be executed before the actual ``select_query``. 
This can usually - be used for setting memory or the an alternate execution engine like :ref:`tez`_/ + Args: + select_query: Singular query that returns a Tabular dataset + stage_query: optional query that should be executed before the actual ``select_query``. This can usually + be used for setting memory or an alternate execution engine like :ref:`tez`. """ query_template = HiveSelectTask._HIVE_QUERY_FORMATTER.format( stage_query_str=stage_query or "", select_query_str=select_query.strip().strip(";") diff --git a/plugins/papermill/flytekitplugins/papermill/task.py b/plugins/papermill/flytekitplugins/papermill/task.py index 197681cdc6..c8ffa825ab 100644 --- a/plugins/papermill/flytekitplugins/papermill/task.py +++ b/plugins/papermill/flytekitplugins/papermill/task.py @@ -147,8 +147,8 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters: @staticmethod def extract_outputs(nb: str) -> LiteralMap: """ - Parse Outputs from Notebook. - This looks for a cell, with the tag "outputs" to be present. + Parse outputs from the notebook. + This looks for a cell with the tag "outputs" to be present. """ with open(nb) as json_file: data = json.load(json_file) @@ -164,9 +164,9 @@ def extract_outputs(nb: str) -> LiteralMap: @staticmethod def render_nb_html(from_nb: str, to: str): """ - render output notebook to html - We are using nbconvert htmlexporter and its classic template - later about how to customize the exporter further. + Render the output notebook to HTML. + We use nbconvert's HTMLExporter with its classic template; + see the nbconvert documentation to customize the exporter further. """ html_exporter = HTMLExporter() html_exporter.template_name = "classic" @@ -213,10 +213,10 @@ def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any: def record_outputs(**kwargs) -> str: """ - Use this method to record outputs from a notebook. - It will convert all outputs to a Flyte understandable format. For Files, Directories, please use FlyteFile or - FlyteDirectory, or wrap up your paths in these decorators. - """ + Use this method to record outputs from a notebook. + It will convert all outputs to a format Flyte understands. For files and directories, please use FlyteFile or + FlyteDirectory, or wrap your paths in these types. 
+ """ if kwargs is None: return "" diff --git a/plugins/pod/flytekitplugins/pod/task.py b/plugins/pod/flytekitplugins/pod/task.py index 5449b7df5b..9733d5596c 100644 --- a/plugins/pod/flytekitplugins/pod/task.py +++ b/plugins/pod/flytekitplugins/pod/task.py @@ -33,7 +33,11 @@ def primary_container_name(self) -> str: class PodFunctionTask(PythonFunctionTask[Pod]): def __init__(self, task_config: Pod, task_function: Callable, **kwargs): super(PodFunctionTask, self).__init__( - task_config=task_config, task_type="sidecar", task_function=task_function, task_type_version=1, **kwargs, + task_config=task_config, + task_type="sidecar", + task_function=task_function, + task_type_version=1, + **kwargs, ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: diff --git a/plugins/spark/flytekitplugins/spark/task.py b/plugins/spark/flytekitplugins/spark/task.py index 7d7253f149..2415392376 100644 --- a/plugins/spark/flytekitplugins/spark/task.py +++ b/plugins/spark/flytekitplugins/spark/task.py @@ -70,7 +70,10 @@ class PysparkFunctionTask(PythonFunctionTask[Spark]): def __init__(self, task_config: Spark, task_function: Callable, **kwargs): super(PysparkFunctionTask, self).__init__( - task_config=task_config, task_type=self._SPARK_TASK_TYPE, task_function=task_function, **kwargs, + task_config=task_config, + task_type=self._SPARK_TASK_TYPE, + task_function=task_function, + **kwargs, ) def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: diff --git a/plugins/tests/awssagemaker/test_hpo.py b/plugins/tests/awssagemaker/test_hpo.py index 0762109e51..741f952d8b 100644 --- a/plugins/tests/awssagemaker/test_hpo.py +++ b/plugins/tests/awssagemaker/test_hpo.py @@ -25,13 +25,21 @@ def test_hpo_for_builtin(): name="builtin-trainer", task_config=SagemakerTrainingJobConfig( training_job_resource_config=TrainingJobResourceConfig( - instance_count=1, instance_type="ml-xlarge", volume_size_in_gb=1, + instance_count=1, + instance_type="ml-xlarge", + volume_size_in_gb=1, + ), + algorithm_specification=AlgorithmSpecification( + algorithm_name=AlgorithmName.XGBOOST, ), - algorithm_specification=AlgorithmSpecification(algorithm_name=AlgorithmName.XGBOOST,), ), ) - hpo = SagemakerHPOTask(name="test", task_config=HPOJob(10, 10, ["x"]), training_task=trainer,) + hpo = SagemakerHPOTask( + name="test", + task_config=HPOJob(10, 10, ["x"]), + training_task=trainer, + ) assert hpo.python_interface.inputs.keys() == { "static_hyperparameters", @@ -59,7 +67,8 @@ def test_hpo_for_builtin(): hyperparameter_tuning_job_config=HyperparameterTuningJobConfig( tuning_strategy=1, tuning_objective=HyperparameterTuningObjective( - objective_type=HyperparameterTuningObjectiveType.MINIMIZE, metric_name="x", + objective_type=HyperparameterTuningObjectiveType.MINIMIZE, + metric_name="x", ), training_job_early_stopping_type=TrainingJobEarlyStoppingType.OFF, ), @@ -73,7 +82,8 @@ def test_hpoconfig_transformer(): o = HyperparameterTuningJobConfig( tuning_strategy=1, tuning_objective=HyperparameterTuningObjective( - objective_type=HyperparameterTuningObjectiveType.MINIMIZE, metric_name="x", + objective_type=HyperparameterTuningObjectiveType.MINIMIZE, + metric_name="x", ), training_job_early_stopping_type=TrainingJobEarlyStoppingType.OFF, ) diff --git a/plugins/tests/awssagemaker/test_training.py b/plugins/tests/awssagemaker/test_training.py index 22ff1cb104..b2a6a14ec3 100644 --- a/plugins/tests/awssagemaker/test_training.py +++ b/plugins/tests/awssagemaker/test_training.py @@ -34,9 +34,13 @@ def 
test_builtin_training(): name="builtin-trainer", task_config=SagemakerTrainingJobConfig( training_job_resource_config=TrainingJobResourceConfig( - instance_count=1, instance_type="ml-xlarge", volume_size_in_gb=1, + instance_count=1, + instance_type="ml-xlarge", + volume_size_in_gb=1, + ), + algorithm_specification=AlgorithmSpecification( + algorithm_name=AlgorithmName.XGBOOST, ), - algorithm_specification=AlgorithmSpecification(algorithm_name=AlgorithmName.XGBOOST,), ), ) @@ -62,8 +66,13 @@ def test_builtin_training(): def test_custom_training(): @task( task_config=SagemakerTrainingJobConfig( - training_job_resource_config=TrainingJobResourceConfig(instance_type="ml-xlarge", volume_size_in_gb=1,), - algorithm_specification=AlgorithmSpecification(algorithm_name=AlgorithmName.CUSTOM,), + training_job_resource_config=TrainingJobResourceConfig( + instance_type="ml-xlarge", + volume_size_in_gb=1, + ), + algorithm_specification=AlgorithmSpecification( + algorithm_name=AlgorithmName.CUSTOM, + ), ) ) def my_custom_trainer(x: int) -> int: @@ -91,7 +100,9 @@ def test_distributed_custom_training(): instance_count=2, # Indicates distributed training distributed_protocol=DistributedProtocol.MPI, ), - algorithm_specification=AlgorithmSpecification(algorithm_name=AlgorithmName.CUSTOM,), + algorithm_specification=AlgorithmSpecification( + algorithm_name=AlgorithmName.CUSTOM, + ), ) ) def my_custom_trainer(x: int) -> int: diff --git a/plugins/tests/pod/test_pod.py b/plugins/tests/pod/test_pod.py index e4d86984d3..59adc0313d 100644 --- a/plugins/tests/pod/test_pod.py +++ b/plugins/tests/pod/test_pod.py @@ -14,9 +14,16 @@ def get_pod_spec(): - a_container = V1Container(name="a container",) + a_container = V1Container( + name="a container", + ) a_container.command = ["fee", "fi", "fo", "fum"] - a_container.volume_mounts = [V1VolumeMount(name="volume mount", mount_path="some/where",)] + a_container.volume_mounts = [ + V1VolumeMount( + name="volume mount", + mount_path="some/where", + ) + ] pod_spec = V1PodSpec(restart_policy="OnFailure", containers=[a_container, V1Container(name="another container")]) return pod_spec diff --git a/pyproject.toml b/pyproject.toml index 339d113e41..f7e217c339 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,16 +1,14 @@ [tool.black] line-length = 120 -exclude = ''' -( - /( - \.eggs # exclude a few common directories in the - | \.git # root of the project - | \.hg - | \.mypy_cache - | \.tox - | _build - | build - | dist - )/ -) -''' + +[tool.isort] +profile = "black" +line_length = 120 + +[tool.pytest.ini_options] +norecursedirs = ["common", "workflows", "spark"] +log_cli = true +log_cli_level = 20 + +[tool.coverage.run] +branch = true diff --git a/requirements-spark3.txt b/requirements-spark3.txt index 4b0f8eff11..69aec3fb38 100644 --- a/requirements-spark3.txt +++ b/requirements-spark3.txt @@ -10,30 +10,23 @@ ansiwrap==0.8.4 # via papermill appdirs==1.4.4 # via black -appnope==0.1.2 - # via - # ipykernel - # ipython async-generator==1.10 # via nbclient attrs==20.3.0 # via - # black # jsonschema # scantree backcall==0.2.0 # via ipython bcrypt==3.2.0 # via paramiko -black==19.10b0 - # via - # flytekit - # papermill +black==20.8b1 + # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.36 +boto3==1.17.39 # via sagemaker-training -botocore==1.20.36 +botocore==1.20.39 # via # boto3 # s3transfer @@ -54,8 +47,10 @@ click==7.1.2 # papermill croniter==1.0.10 # via flytekit -cryptography==3.4.6 - # via paramiko +cryptography==3.4.7 + # via + # paramiko + # secretstorage 
dataclasses-json==0.5.2 # via flytekit decorator==4.4.2 @@ -100,6 +95,10 @@ ipython==7.21.0 # via ipykernel jedi==0.18.0 # via ipython +jeepney==0.6.0 + # via + # keyring + # secretstorage jinja2==2.11.3 # via nbconvert jmespath==0.10.0 @@ -134,7 +133,9 @@ marshmallow==3.10.0 mistune==0.8.4 # via nbconvert mypy-extensions==0.4.3 - # via typing-inspect + # via + # black + # typing-inspect natsort==7.1.1 # via flytekit nbclient==0.5.3 @@ -250,6 +251,8 @@ scantree==0.0.1 # via dirhash scipy==1.6.2 # via sagemaker-training +secretstorage==3.3.1 + # via keyring six==1.15.0 # via # bcrypt @@ -300,7 +303,9 @@ traitlets==5.0.5 typed-ast==1.4.2 # via black typing-extensions==3.7.4.3 - # via typing-inspect + # via + # black + # typing-inspect typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11 diff --git a/requirements.txt b/requirements.txt index 6820125a79..866325c7c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,30 +10,23 @@ ansiwrap==0.8.4 # via papermill appdirs==1.4.4 # via black -appnope==0.1.2 - # via - # ipykernel - # ipython async-generator==1.10 # via nbclient attrs==20.3.0 # via - # black # jsonschema # scantree backcall==0.2.0 # via ipython bcrypt==3.2.0 # via paramiko -black==19.10b0 - # via - # flytekit - # papermill +black==20.8b1 + # via papermill bleach==3.3.0 # via nbconvert -boto3==1.17.36 +boto3==1.17.39 # via sagemaker-training -botocore==1.20.36 +botocore==1.20.39 # via # boto3 # s3transfer @@ -54,8 +47,10 @@ click==7.1.2 # papermill croniter==1.0.10 # via flytekit -cryptography==3.4.6 - # via paramiko +cryptography==3.4.7 + # via + # paramiko + # secretstorage dataclasses-json==0.5.2 # via flytekit decorator==4.4.2 @@ -100,6 +95,10 @@ ipython==7.21.0 # via ipykernel jedi==0.18.0 # via ipython +jeepney==0.6.0 + # via + # keyring + # secretstorage jinja2==2.11.3 # via nbconvert jmespath==0.10.0 @@ -134,7 +133,9 @@ marshmallow==3.10.0 mistune==0.8.4 # via nbconvert mypy-extensions==0.4.3 - # via typing-inspect + # via + # black + # typing-inspect natsort==7.1.1 # via flytekit nbclient==0.5.3 @@ -250,6 +251,8 @@ scantree==0.0.1 # via dirhash scipy==1.6.2 # via sagemaker-training +secretstorage==3.3.1 + # via keyring six==1.15.0 # via # bcrypt @@ -300,7 +303,9 @@ traitlets==5.0.5 typed-ast==1.4.2 # via black typing-extensions==3.7.4.3 - # via typing-inspect + # via + # black + # typing-inspect typing-inspect==0.6.0 # via dataclasses-json urllib3==1.25.11 diff --git a/setup.cfg b/setup.cfg index d2d5096bfb..c5c8bbd7a1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,11 +1,3 @@ -[isort] -multi_line_output = 3 -include_trailing_comma = True -force_grid_wrap = 0 -use_parentheses = True -ensure_newline_before_comments = True -line_length = 120 - [flake8] max-line-length = 120 extend-ignore = E203, E266, E501, W503, E741 @@ -21,14 +13,5 @@ ignore_missing_imports = True follow_imports = skip cache_dir = /dev/null -[tool:pytest] -norecursedirs = common workflows spark -log_cli = true -log_cli_level = 20 - -[coverage:run] -branch = True - [metadata] license_files = LICENSE - diff --git a/setup.py b/setup.py index bf2e9e08b3..0ea2a69fa8 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ sidecar = ["k8s-proto>=0.0.3,<1.0.0"] schema = ["numpy>=1.14.0,<2.0.0", "pandas>=0.22.0,<2.0.0", "pyarrow>2.0.0,<4.0.0"] hive_sensor = ["hmsclient>=0.0.1,<1.0.0"] -notebook = ["papermill>=1.2.0", "nbconvert>=6.0.7", "ipykernel>=5.0.0", "black==19.10b0"] +notebook = ["papermill>=1.2.0", "nbconvert>=6.0.7", "ipykernel>=5.0.0"] sagemaker = ["sagemaker-training>=3.6.2,<4.0.0"] 
all_but_spark = sidecar + schema + hive_sensor + notebook + sagemaker diff --git a/tests/flytekit/common/parameterizers.py b/tests/flytekit/common/parameterizers.py index b2b90c9e81..b79ec97ca3 100644 --- a/tests/flytekit/common/parameterizers.py +++ b/tests/flytekit/common/parameterizers.py @@ -32,16 +32,28 @@ ) ), types.LiteralType( - blob=_core_types.BlobType(format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,) + blob=_core_types.BlobType( + format="", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) ), types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) ), types.LiteralType( - blob=_core_types.BlobType(format="", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ), types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ), ] @@ -60,7 +72,8 @@ LIST_OF_INTERFACES = [ interface.TypedInterface( - {"a": interface.Variable(t, "description 1")}, {"b": interface.Variable(t, "description 2")}, + {"a": interface.Variable(t, "description 1")}, + {"b": interface.Variable(t, "description 2")}, ) for t in LIST_OF_ALL_LITERAL_TYPES ] @@ -95,7 +108,13 @@ LIST_OF_TASK_METADATA = [ task.TaskMetadata( - discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated, + discoverable, + runtime_metadata, + timeout, + retry_strategy, + interruptible, + discovery_version, + deprecated, ) for discoverable, runtime_metadata, timeout, retry_strategy, interruptible, discovery_version, deprecated in product( [True, False], @@ -117,7 +136,12 @@ interfaces, {"a": 1, "b": [1, 2, 3], "c": "abc", "d": {"x": 1, "y": 2, "z": 3}}, container=task.Container( - "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {"a": "b"}, + {"d": "e"}, ), ) for task_metadata, interfaces, resources in product(LIST_OF_TASK_METADATA, LIST_OF_INTERFACES, LIST_OF_RESOURCES) @@ -125,7 +149,12 @@ LIST_OF_CONTAINERS = [ task.Container( - "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {"a": "b"}, + {"d": "e"}, ) for resources in LIST_OF_RESOURCES ] @@ -137,7 +166,10 @@ (literals.Scalar(primitive=literals.Primitive(float_value=500.0)), 500.0), (literals.Scalar(primitive=literals.Primitive(boolean=True)), True), (literals.Scalar(primitive=literals.Primitive(string_value="hello")), "hello"), - (literals.Scalar(primitive=literals.Primitive(duration=timedelta(seconds=5))), timedelta(seconds=5),), + ( + literals.Scalar(primitive=literals.Primitive(duration=timedelta(seconds=5))), + timedelta(seconds=5), + ), (literals.Scalar(none_type=literals.Void()), None), ( literals.Scalar( diff --git a/tests/flytekit/common/workflows/python.py b/tests/flytekit/common/workflows/python.py index ff8f41f29d..a0d423b86c 100644 --- a/tests/flytekit/common/workflows/python.py +++ 
b/tests/flytekit/common/workflows/python.py @@ -29,7 +29,10 @@ def sum_non_none(workflow_parameters, value1_to_print, value2_to_print, out): @inputs( - value1_to_add=Types.Integer, value2_to_add=Types.Integer, value3_to_add=Types.Integer, value4_to_add=Types.Integer, + value1_to_add=Types.Integer, + value2_to_add=Types.Integer, + value3_to_add=Types.Integer, + value4_to_add=Types.Integer, ) @outputs(out=Types.Integer) @python_task(cache_version="1") diff --git a/tests/flytekit/common/workflows/sagemaker.py b/tests/flytekit/common/workflows/sagemaker.py index b0fe7d36d6..044bba29f2 100644 --- a/tests/flytekit/common/workflows/sagemaker.py +++ b/tests/flytekit/common/workflows/sagemaker.py @@ -54,7 +54,9 @@ builtin_algorithm_training_job_task2 = SdkBuiltinAlgorithmTrainingJobTask( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=25, + instance_type="ml.m4.xlarge", + instance_count=1, + volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, @@ -85,7 +87,8 @@ class SageMakerHPO(object): default=_HyperparameterTuningJobConfig( tuning_strategy=HyperparameterTuningStrategy.BAYESIAN, tuning_objective=HyperparameterTuningObjective( - objective_type=HyperparameterTuningObjectiveType.MINIMIZE, metric_name="validation:error", + objective_type=HyperparameterTuningObjectiveType.MINIMIZE, + metric_name="validation:error", ), training_job_early_stopping_type=TrainingJobEarlyStoppingType.AUTO, ), @@ -105,7 +108,10 @@ class SageMakerHPO(object): sagemaker_hpo_lp = SageMakerHPO.create_launch_plan() with _configuration.TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../common/configs/local.config",), + _os.path.join( + _os.path.dirname(_os.path.realpath(__file__)), + "../../common/configs/local.config", + ), internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, ): print("Printing WF definition") diff --git a/tests/flytekit/common/workflows/sidecar.py b/tests/flytekit/common/workflows/sidecar.py index aa01094cf6..ff0d4e7484 100644 --- a/tests/flytekit/common/workflows/sidecar.py +++ b/tests/flytekit/common/workflows/sidecar.py @@ -10,10 +10,16 @@ def generate_pod_spec_for_task(): pod_spec = generated_pb2.PodSpec() - secondary_container = generated_pb2.Container(name="secondary", image="alpine",) + secondary_container = generated_pb2.Container( + name="secondary", + image="alpine", + ) secondary_container.command.extend(["/bin/sh"]) secondary_container.args.extend(["-c", "echo hi sidecar world > /data/message.txt"]) - shared_volume_mount = generated_pb2.VolumeMount(name="shared-data", mountPath="/data",) + shared_volume_mount = generated_pb2.VolumeMount( + name="shared-data", + mountPath="/data", + ) secondary_container.volumeMounts.extend([shared_volume_mount]) primary_container = generated_pb2.Container(name="primary") @@ -23,7 +29,11 @@ def generate_pod_spec_for_task(): [ generated_pb2.Volume( name="shared-data", - volumeSource=generated_pb2.VolumeSource(emptyDir=generated_pb2.EmptyDirVolumeSource(medium="Memory",)), + volumeSource=generated_pb2.VolumeSource( + emptyDir=generated_pb2.EmptyDirVolumeSource( + medium="Memory", + ) + ), ) ] ) @@ -32,7 +42,8 @@ def generate_pod_spec_for_task(): @sidecar_task( - pod_spec=generate_pod_spec_for_task(), primary_container_name="primary", + pod_spec=generate_pod_spec_for_task(), + primary_container_name="primary", ) def a_sidecar_task(wfparams): while not 
os.path.isfile("/data/message.txt"): diff --git a/tests/flytekit/common/workflows/simple.py b/tests/flytekit/common/workflows/simple.py index 2e21a7f300..f264fe39bf 100644 --- a/tests/flytekit/common/workflows/simple.py +++ b/tests/flytekit/common/workflows/simple.py @@ -105,4 +105,10 @@ class SimpleWorkflow(object): c = subtract_one(a=input_1) d = write_special_types() - e = read_special_types(a=d.outputs.a, b=d.outputs.b, c=d.outputs.c, d=d.outputs.d, e=d.outputs.e,) + e = read_special_types( + a=d.outputs.a, + b=d.outputs.b, + c=d.outputs.c, + d=d.outputs.d, + e=d.outputs.e, + ) diff --git a/tests/flytekit/unit/bin/test_python_entrypoint.py b/tests/flytekit/unit/bin/test_python_entrypoint.py index 4813676fe1..bea2c9f333 100644 --- a/tests/flytekit/unit/bin/test_python_entrypoint.py +++ b/tests/flytekit/unit/bin/test_python_entrypoint.py @@ -30,7 +30,8 @@ def test_single_step_entrypoint_in_proc(): ): with _utils.AutoDeletingTempDir("in") as input_dir: literal_map = _type_helpers.pack_python_std_map_to_literal_map( - {"a": 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs), + {"a": 9}, + _type_map_from_variable_map(_task_defs.add_one.interface.inputs), ) input_file = os.path.join(input_dir.name, "inputs.pb") _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file) @@ -46,7 +47,8 @@ def test_single_step_entrypoint_in_proc(): ) p = _utils.load_proto_from_file( - _literals_pb2.LiteralMap, os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME), + _literals_pb2.LiteralMap, + os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME), ) raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std( _literal_models.LiteralMap.from_flyte_idl(p), @@ -63,7 +65,8 @@ def test_single_step_entrypoint_out_of_proc(): ): with _utils.AutoDeletingTempDir("in") as input_dir: literal_map = _type_helpers.pack_python_std_map_to_literal_map( - {"a": 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs), + {"a": 9}, + _type_map_from_variable_map(_task_defs.add_one.interface.inputs), ) input_file = os.path.join(input_dir.name, "inputs.pb") _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file) @@ -78,7 +81,8 @@ def test_single_step_entrypoint_out_of_proc(): assert result.exit_code == 0 p = _utils.load_proto_from_file( - _literals_pb2.LiteralMap, os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME), + _literals_pb2.LiteralMap, + os.path.join(output_dir.name, _constants.OUTPUT_FILE_NAME), ) raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std( _literal_models.LiteralMap.from_flyte_idl(p), @@ -95,7 +99,8 @@ def test_arrayjob_entrypoint_in_proc(): ): with _utils.AutoDeletingTempDir("dir") as dir: literal_map = _type_helpers.pack_python_std_map_to_literal_map( - {"a": 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs), + {"a": 9}, + _type_map_from_variable_map(_task_defs.add_one.interface.inputs), ) input_dir = os.path.join(dir.name, "1") @@ -128,7 +133,8 @@ def test_arrayjob_entrypoint_in_proc(): raw_map = _type_helpers.unpack_literal_map_to_sdk_python_std( _literal_models.LiteralMap.from_flyte_idl( _utils.load_proto_from_file( - _literals_pb2.LiteralMap, os.path.join(input_dir, _constants.OUTPUT_FILE_NAME), + _literals_pb2.LiteralMap, + os.path.join(input_dir, _constants.OUTPUT_FILE_NAME), ) ), _type_map_from_variable_map(_task_defs.add_one.interface.outputs), diff --git a/tests/flytekit/unit/cli/auth/test_discovery.py b/tests/flytekit/unit/cli/auth/test_discovery.py index 5813d18bf0..c75427f35d 100644 --- 
a/tests/flytekit/unit/cli/auth/test_discovery.py +++ b/tests/flytekit/unit/cli/auth/test_discovery.py @@ -11,7 +11,9 @@ def test_get_authorization_endpoints(): auth_endpoint = "http://flyte-admin.com/authorization" token_endpoint = "http://flyte-admin.com/token" responses.add( - responses.GET, discovery_url, json={"authorization_endpoint": auth_endpoint, "token_endpoint": token_endpoint}, + responses.GET, + discovery_url, + json={"authorization_endpoint": auth_endpoint, "token_endpoint": token_endpoint}, ) discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) @@ -26,7 +28,9 @@ def test_get_authorization_endpoints_relative(): auth_endpoint = "/authorization" token_endpoint = "/token" responses.add( - responses.GET, discovery_url, json={"authorization_endpoint": auth_endpoint, "token_endpoint": token_endpoint}, + responses.GET, + discovery_url, + json={"authorization_endpoint": auth_endpoint, "token_endpoint": token_endpoint}, ) discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) @@ -38,7 +42,9 @@ def test_get_authorization_endpoints_relative(): def test_get_authorization_endpoints_missing_authorization_endpoint(): discovery_url = "http://flyte-admin.com/discovery" responses.add( - responses.GET, discovery_url, json={"token_endpoint": "http://flyte-admin.com/token"}, + responses.GET, + discovery_url, + json={"token_endpoint": "http://flyte-admin.com/token"}, ) discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) @@ -50,7 +56,9 @@ def test_get_authorization_endpoints_missing_authorization_endpoint(): def test_get_authorization_endpoints_missing_token_endpoint(): discovery_url = "http://flyte-admin.com/discovery" responses.add( - responses.GET, discovery_url, json={"authorization_endpoint": "http://flyte-admin.com/authorization"}, + responses.GET, + discovery_url, + json={"authorization_endpoint": "http://flyte-admin.com/authorization"}, ) discovery_client = _discovery.DiscoveryClient(discovery_url=discovery_url) diff --git a/tests/flytekit/unit/cli/pyflyte/conftest.py b/tests/flytekit/unit/cli/pyflyte/conftest.py index a829db158d..723fb4878b 100644 --- a/tests/flytekit/unit/cli/pyflyte/conftest.py +++ b/tests/flytekit/unit/cli/pyflyte/conftest.py @@ -21,7 +21,10 @@ def _fake_module_load(names): @pytest.yield_fixture( scope="function", params=[ - os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../common/configs/local.config",), + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "../../../common/configs/local.config", + ), "/foo/bar", None, ], diff --git a/tests/flytekit/unit/cli/test_cli_helpers.py b/tests/flytekit/unit/cli/test_cli_helpers.py index 8a6c6e0263..9188b2fc3a 100644 --- a/tests/flytekit/unit/cli/test_cli_helpers.py +++ b/tests/flytekit/unit/cli/test_cli_helpers.py @@ -111,7 +111,8 @@ def test_hydrate_workflow_template(): id="launchplan_ref", workflow_node=_core_workflow_pb2.WorkflowNode( launchplan_ref=_identifier_pb2.Identifier( - resource_type=_identifier_pb2.LAUNCH_PLAN, project="project2", + resource_type=_identifier_pb2.LAUNCH_PLAN, + project="project2", ) ), ) @@ -121,7 +122,9 @@ def test_hydrate_workflow_template(): id="sub_workflow_ref", workflow_node=_core_workflow_pb2.WorkflowNode( sub_workflow_ref=_identifier_pb2.Identifier( - resource_type=_identifier_pb2.WORKFLOW, project="project2", domain="domain2", + resource_type=_identifier_pb2.WORKFLOW, + project="project2", + domain="domain2", ) ), ) @@ -245,7 +248,13 @@ def 
test_hydrate_registration_parameters__launch_plan_already_set(): ) ), ) - identifier, entity = hydrate_registration_parameters(LAUNCH_PLAN, "project", "domain", "12345", launch_plan,) + identifier, entity = hydrate_registration_parameters( + LAUNCH_PLAN, + "project", + "domain", + "12345", + launch_plan, + ) assert identifier == _identifier_pb2.Identifier( resource_type=_identifier_pb2.LAUNCH_PLAN, project="project2", @@ -258,16 +267,30 @@ def test_hydrate_registration_parameters__launch_plan_already_set(): def test_hydrate_registration_parameters__launch_plan_nothing_set(): launch_plan = _launch_plan_pb2.LaunchPlan( - id=_identifier_pb2.Identifier(resource_type=_identifier_pb2.LAUNCH_PLAN, name="lp_name",), + id=_identifier_pb2.Identifier( + resource_type=_identifier_pb2.LAUNCH_PLAN, + name="lp_name", + ), spec=_launch_plan_pb2.LaunchPlanSpec( - workflow_id=_identifier_pb2.Identifier(resource_type=_identifier_pb2.WORKFLOW, name="workflow_name",) + workflow_id=_identifier_pb2.Identifier( + resource_type=_identifier_pb2.WORKFLOW, + name="workflow_name", + ) ), ) identifier, entity = hydrate_registration_parameters( - _identifier_pb2.LAUNCH_PLAN, "project", "domain", "12345", launch_plan, + _identifier_pb2.LAUNCH_PLAN, + "project", + "domain", + "12345", + launch_plan, ) assert identifier == _identifier_pb2.Identifier( - resource_type=_identifier_pb2.LAUNCH_PLAN, project="project", domain="domain", name="lp_name", version="12345", + resource_type=_identifier_pb2.LAUNCH_PLAN, + project="project", + domain="domain", + name="lp_name", + version="12345", ) assert entity.spec.workflow_id == _identifier_pb2.Identifier( resource_type=_identifier_pb2.WORKFLOW, @@ -282,7 +305,11 @@ def test_hydrate_registration_parameters__task_already_set(): task = _task_pb2.TaskSpec( template=_core_task_pb2.TaskTemplate( id=_identifier_pb2.Identifier( - resource_type=_identifier_pb2.TASK, project="project2", domain="domain2", name="name", version="abc", + resource_type=_identifier_pb2.TASK, + project="project2", + domain="domain2", + name="name", + version="abc", ), ) ) @@ -290,7 +317,11 @@ def test_hydrate_registration_parameters__task_already_set(): assert ( identifier == _identifier_pb2.Identifier( - resource_type=_identifier_pb2.TASK, project="project2", domain="domain2", name="name", version="abc", + resource_type=_identifier_pb2.TASK, + project="project2", + domain="domain2", + name="name", + version="abc", ) == entity.template.id ) @@ -299,14 +330,21 @@ def test_hydrate_registration_parameters__task_already_set(): def test_hydrate_registration_parameters__task_nothing_set(): task = _task_pb2.TaskSpec( template=_core_task_pb2.TaskTemplate( - id=_identifier_pb2.Identifier(resource_type=_identifier_pb2.TASK, name="name",), + id=_identifier_pb2.Identifier( + resource_type=_identifier_pb2.TASK, + name="name", + ), ) ) identifier, entity = hydrate_registration_parameters(_identifier_pb2.TASK, "project", "domain", "12345", task) assert ( identifier == _identifier_pb2.Identifier( - resource_type=_identifier_pb2.TASK, project="project", domain="domain", name="name", version="12345", + resource_type=_identifier_pb2.TASK, + project="project", + domain="domain", + name="name", + version="12345", ) == entity.template.id ) @@ -330,7 +368,11 @@ def test_hydrate_registration_parameters__workflow_already_set(): assert ( identifier == _identifier_pb2.Identifier( - resource_type=_identifier_pb2.WORKFLOW, project="project2", domain="domain2", name="name", version="abc", + resource_type=_identifier_pb2.WORKFLOW, + 
project="project2", + domain="domain2", + name="name", + version="abc", ) == entity.template.id ) @@ -339,7 +381,10 @@ def test_hydrate_registration_parameters__workflow_already_set(): def test_hydrate_registration_parameters__workflow_nothing_set(): workflow = _workflow_pb2.WorkflowSpec( template=_core_workflow_pb2.WorkflowTemplate( - id=_identifier_pb2.Identifier(resource_type=_identifier_pb2.WORKFLOW, name="name",), + id=_identifier_pb2.Identifier( + resource_type=_identifier_pb2.WORKFLOW, + name="name", + ), nodes=[ _core_workflow_pb2.Node( id="foo", @@ -356,13 +401,21 @@ def test_hydrate_registration_parameters__workflow_nothing_set(): assert ( identifier == _identifier_pb2.Identifier( - resource_type=_identifier_pb2.WORKFLOW, project="project", domain="domain", name="name", version="12345", + resource_type=_identifier_pb2.WORKFLOW, + project="project", + domain="domain", + name="name", + version="12345", ) == entity.template.id ) assert len(workflow.template.nodes) == 1 assert workflow.template.nodes[0].task_node.reference_id == _identifier_pb2.Identifier( - resource_type=_identifier_pb2.TASK, project="project", domain="domain", name="task1", version="12345", + resource_type=_identifier_pb2.TASK, + project="project", + domain="domain", + name="task1", + version="12345", ) @@ -401,5 +454,9 @@ def test_hydrate_registration_parameters__subworkflows(): ) assert entity.sub_workflows[0].id == _identifier_pb2.Identifier( - resource_type=_identifier_pb2.WORKFLOW, project="project", domain="domain", name="subworkflow", version="12345", + resource_type=_identifier_pb2.WORKFLOW, + project="project", + domain="domain", + name="subworkflow", + version="12345", ) diff --git a/tests/flytekit/unit/cli/test_flyte_cli.py b/tests/flytekit/unit/cli/test_flyte_cli.py index da63efa1dc..0fe7e54555 100644 --- a/tests/flytekit/unit/cli/test_flyte_cli.py +++ b/tests/flytekit/unit/cli/test_flyte_cli.py @@ -34,7 +34,8 @@ def my_task(wf_params, a, b): def test__extract_files(load_mock): t = get_sample_task() with TemporaryConfiguration( - "", internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, + "", + internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, ): task_spec = t.serialize() @@ -52,7 +53,11 @@ def test__extract_files(load_mock): @_mock.patch("flytekit.clis.flyte_cli.main._load_proto_from_file") def test__extract_files_with_unspecified_resource_type(load_mock): id = _core_identifier.Identifier( - _core_identifier.ResourceType.UNSPECIFIED, "myproject", "development", "name", "v", + _core_identifier.ResourceType.UNSPECIFIED, + "myproject", + "development", + "name", + "v", ) load_mock.return_value = id.to_flyte_idl() diff --git a/tests/flytekit/unit/common_tests/exceptions/test_system.py b/tests/flytekit/unit/common_tests/exceptions/test_system.py index 4610703efa..d53ed00f6c 100644 --- a/tests/flytekit/unit/common_tests/exceptions/test_system.py +++ b/tests/flytekit/unit/common_tests/exceptions/test_system.py @@ -48,7 +48,9 @@ def test_flyte_entrypoint_not_loadable_exception(): try: raise system.FlyteEntrypointNotLoadable( - "fake.module", task_name="secret_task", additional_msg="Shouldn't have used a fake module!", + "fake.module", + task_name="secret_task", + additional_msg="Shouldn't have used a fake module!", ) except Exception as e: assert ( diff --git a/tests/flytekit/unit/common_tests/exceptions/test_user.py b/tests/flytekit/unit/common_tests/exceptions/test_user.py index f8851bb122..e3b3fbd319 
100644 --- a/tests/flytekit/unit/common_tests/exceptions/test_user.py +++ b/tests/flytekit/unit/common_tests/exceptions/test_user.py @@ -22,7 +22,10 @@ def test_flyte_type_exception(): try: raise user.FlyteTypeException( - "int", ("list", "set"), received_value=1, additional_msg="That was a bad idea!", + "int", + ("list", "set"), + received_value=1, + additional_msg="That was a bad idea!", ) except Exception as e: assert ( diff --git a/tests/flytekit/unit/common_tests/tasks/test_execution_params.py b/tests/flytekit/unit/common_tests/tasks/test_execution_params.py index 76509d54f0..cf634bfa96 100644 --- a/tests/flytekit/unit/common_tests/tasks/test_execution_params.py +++ b/tests/flytekit/unit/common_tests/tasks/test_execution_params.py @@ -29,7 +29,9 @@ def test_secrets_manager_get_file(): with pytest.raises(ValueError): sec.get_secrets_file("", "x") assert sec.get_secrets_file("group", "test") == os.path.join( - secrets.SECRETS_DEFAULT_DIR.get(), "group", f"{secrets.SECRETS_FILE_PREFIX.get()}test", + secrets.SECRETS_DEFAULT_DIR.get(), + "group", + f"{secrets.SECRETS_FILE_PREFIX.get()}test", ) diff --git a/tests/flytekit/unit/common_tests/tasks/test_task.py b/tests/flytekit/unit/common_tests/tasks/test_task.py index 9212211b0e..e33a757412 100644 --- a/tests/flytekit/unit/common_tests/tasks/test_task.py +++ b/tests/flytekit/unit/common_tests/tasks/test_task.py @@ -21,7 +21,8 @@ def test_fetch_latest(mock_url, mock_client_manager): mock_url.get.return_value = "localhost" admin_task = _task_models.Task( - _identifier.Identifier(_identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"), _MagicMock(), + _identifier.Identifier(_identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"), + _MagicMock(), ) mock_client = _MagicMock() mock_client.list_tasks_paginated = _MagicMock(return_value=([admin_task], "")) @@ -58,7 +59,10 @@ def my_task(wf_params, a, b): def test_task_serialization(): t = get_sample_task() with TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../../common/configs/local.config",), + _os.path.join( + _os.path.dirname(_os.path.realpath(__file__)), + "../../../common/configs/local.config", + ), internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, ): s = t.serialize() diff --git a/tests/flytekit/unit/common_tests/test_interface.py b/tests/flytekit/unit/common_tests/test_interface.py index e3eb65f081..b6627c1a6b 100644 --- a/tests/flytekit/unit/common_tests/test_interface.py +++ b/tests/flytekit/unit/common_tests/test_interface.py @@ -21,19 +21,23 @@ def test_binding_data_primitive_static(): with pytest.raises(_user_exceptions.FlyteTypeException): interface.BindingData.from_python_std( - primitives.Float.to_flyte_literal_type(), "abc", + primitives.Float.to_flyte_literal_type(), + "abc", ) with pytest.raises(_user_exceptions.FlyteTypeException): interface.BindingData.from_python_std( - primitives.Float.to_flyte_literal_type(), [1.0, 2.0, 3.0], + primitives.Float.to_flyte_literal_type(), + [1.0, 2.0, 3.0], ) def test_binding_data_list_static(): upstream_nodes = set() bd = interface.BindingData.from_python_std( - containers.List(primitives.String).to_flyte_literal_type(), ["abc", "cde"], upstream_nodes=upstream_nodes, + containers.List(primitives.String).to_flyte_literal_type(), + ["abc", "cde"], + upstream_nodes=upstream_nodes, ) assert len(upstream_nodes) == 0 @@ -47,7 +51,8 @@ def test_binding_data_list_static(): with pytest.raises(_user_exceptions.FlyteTypeException): 
interface.BindingData.from_python_std( - containers.List(primitives.String).to_flyte_literal_type(), "abc", + containers.List(primitives.String).to_flyte_literal_type(), + "abc", ) with pytest.raises(_user_exceptions.FlyteTypeException): diff --git a/tests/flytekit/unit/common_tests/test_launch_plan.py b/tests/flytekit/unit/common_tests/test_launch_plan.py index 495aa0d3c4..92943d37f1 100644 --- a/tests/flytekit/unit/common_tests/test_launch_plan.py +++ b/tests/flytekit/unit/common_tests/test_launch_plan.py @@ -18,7 +18,10 @@ def test_default_assumable_iam_role(): with _configuration.TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../common/configs/local.config",) + _os.path.join( + _os.path.dirname(_os.path.realpath(__file__)), + "../../common/configs/local.config", + ) ): workflow_to_test = _workflow.workflow( {}, @@ -45,7 +48,10 @@ def test_hard_coded_assumable_iam_role(): def test_default_deprecated_role(): with _configuration.TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../common/configs/deprecated_local.config",) + _os.path.join( + _os.path.dirname(_os.path.realpath(__file__)), + "../../common/configs/deprecated_local.config", + ) ): workflow_to_test = _workflow.workflow( {}, @@ -150,7 +156,11 @@ def test_schedule(schedule, cron_expression, cron_schedule): "default_input": _workflow.Input(_types.Types.Integer, default=5), }, ) - lp = workflow_to_test.create_launch_plan(fixed_inputs={"required_input": 5}, schedule=schedule, role="what",) + lp = workflow_to_test.create_launch_plan( + fixed_inputs={"required_input": 5}, + schedule=schedule, + role="what", + ) assert lp.entity_metadata.schedule.kickoff_time_input_arg is None assert lp.entity_metadata.schedule.cron_expression == cron_expression assert lp.entity_metadata.schedule.cron_schedule == cron_schedule @@ -180,7 +190,8 @@ def test_schedule_pointing_to_datetime(): }, ) lp = workflow_to_test.create_launch_plan( - schedule=_schedules.CronSchedule("* * ? * * *", kickoff_time_input_arg="required_input"), role="what", + schedule=_schedules.CronSchedule("* * ? * * *", kickoff_time_input_arg="required_input"), + role="what", ) assert lp.entity_metadata.schedule.kickoff_time_input_arg == "required_input" assert lp.entity_metadata.schedule.cron_expression == "* * ? 
* * *" @@ -310,10 +321,16 @@ def test_serialize(): }, ) workflow_to_test.id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p", "d", "n", "v") - lp = workflow_to_test.create_launch_plan(fixed_inputs={"required_input": 5}, role="iam_role",) + lp = workflow_to_test.create_launch_plan( + fixed_inputs={"required_input": 5}, + role="iam_role", + ) with _configuration.TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../common/configs/local.config",), + _os.path.join( + _os.path.dirname(_os.path.realpath(__file__)), + "../../common/configs/local.config", + ), internal_overrides={"image": "myflyteimage:v123", "project": "myflyteproject", "domain": "development"}, ): s = lp.serialize() @@ -360,9 +377,12 @@ def test_raw_data_output_prefix(): }, ) lp = workflow_to_test.create_launch_plan( - fixed_inputs={"required_input": 5}, raw_output_data_prefix="s3://bucket-name", + fixed_inputs={"required_input": 5}, + raw_output_data_prefix="s3://bucket-name", ) assert lp.raw_output_data_config.output_location_prefix == "s3://bucket-name" - lp2 = workflow_to_test.create_launch_plan(fixed_inputs={"required_input": 5},) + lp2 = workflow_to_test.create_launch_plan( + fixed_inputs={"required_input": 5}, + ) assert lp2.raw_output_data_config.output_location_prefix == "" diff --git a/tests/flytekit/unit/common_tests/test_nodes.py b/tests/flytekit/unit/common_tests/test_nodes.py index 7d2b97e81c..8dd90097c6 100644 --- a/tests/flytekit/unit/common_tests/test_nodes.py +++ b/tests/flytekit/unit/common_tests/test_nodes.py @@ -26,7 +26,8 @@ def testy_test(wf_params, a, b): [], [ _literals.Binding( - "a", _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), + "a", + _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), ) ], _core_workflow_models.NodeMetadata("abc", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), @@ -58,7 +59,8 @@ def testy_test(wf_params, a, b): [n], [ _literals.Binding( - "a", _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), n.outputs.b), + "a", + _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), n.outputs.b), ) ], _core_workflow_models.NodeMetadata("abc2", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), @@ -93,7 +95,8 @@ def testy_test(wf_params, a, b): [], [ _literals.Binding( - "a", _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), + "a", + _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), ) ], _core_workflow_models.NodeMetadata("abc3", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), @@ -131,7 +134,8 @@ def testy_test(wf_params, a, b): [], [ _literals.Binding( - "a", _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), + "a", + _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), ) ], _core_workflow_models.NodeMetadata("abc4", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), @@ -191,7 +195,11 @@ def testy_test(wf_params, a, b): # Test floating ID testy_test._id = _identifier.Identifier( - _identifier.ResourceType.TASK, "new_project", "new_domain", "new_name", "new_version", + _identifier.ResourceType.TASK, + "new_project", + "new_domain", + "new_name", + "new_version", ) assert n.reference_id.project == "new_project" assert n.reference_id.domain == "new_domain" @@ -219,7 +227,8 @@ class test_workflow(object): 
[], [ _literals.Binding( - "a", _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), + "a", + _interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), ) ], _core_workflow_models.NodeMetadata("abc", _datetime.timedelta(minutes=15), _literals.RetryStrategy(3)), @@ -267,7 +276,11 @@ class test_workflow(object): # Test floating ID lp._id = _identifier.Identifier( - _identifier.ResourceType.TASK, "new_project", "new_domain", "new_name", "new_version", + _identifier.ResourceType.TASK, + "new_project", + "new_domain", + "new_name", + "new_version", ) assert n.launchplan_ref.project == "new_project" assert n.launchplan_ref.domain == "new_domain" diff --git a/tests/flytekit/unit/common_tests/test_schedules.py b/tests/flytekit/unit/common_tests/test_schedules.py index 0009b424dc..4e6f231307 100644 --- a/tests/flytekit/unit/common_tests/test_schedules.py +++ b/tests/flytekit/unit/common_tests/test_schedules.py @@ -92,7 +92,8 @@ def test_cron_schedule_schedule_validation(schedule): @_pytest.mark.parametrize( - "schedule", ["foo", "* *"], + "schedule", + ["foo", "* *"], ) def test_cron_schedule_schedule_validation_invalid(schedule): with _pytest.raises(_user_exceptions.FlyteAssertion): diff --git a/tests/flytekit/unit/common_tests/test_translator.py b/tests/flytekit/unit/common_tests/test_translator.py index 07a9437826..70927fbfdf 100644 --- a/tests/flytekit/unit/common_tests/test_translator.py +++ b/tests/flytekit/unit/common_tests/test_translator.py @@ -58,7 +58,10 @@ def my_wf(a: int, b: str) -> (int, str): sdk_task = get_serializable(OrderedDict(), serialization_settings, t1, True) assert "pyflyte-execute" in sdk_task.container.args - lp = LaunchPlan.create("testlp", my_wf,) + lp = LaunchPlan.create( + "testlp", + my_wf, + ) sdk_lp = get_serializable(OrderedDict(), serialization_settings, lp) assert sdk_lp.id.name == "testlp" diff --git a/tests/flytekit/unit/common_tests/test_workflow.py b/tests/flytekit/unit/common_tests/test_workflow.py index ae1ef6df34..13a9e65d8e 100644 --- a/tests/flytekit/unit/common_tests/test_workflow.py +++ b/tests/flytekit/unit/common_tests/test_workflow.py @@ -137,7 +137,8 @@ class my_workflow(object): a = _local_workflow.Output("a", n1.outputs.b, sdk_type=primitives.Integer) w = _local_workflow.build_sdk_workflow_from_metaclass( - my_workflow, on_failure=_workflow_models.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE, + my_workflow, + on_failure=_workflow_models.WorkflowMetadata.OnFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE, ) assert w.should_create_default_launch_plan is True @@ -223,7 +224,9 @@ def my_list_task(wf_params, a, b): wf_out = [ _local_workflow.Output( - "nested_out", [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]], sdk_type=[[primitives.Integer]], + "nested_out", + [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]], + sdk_type=[[primitives.Integer]], ), _local_workflow.Output("scalar_out", n1.outputs.b, sdk_type=primitives.Integer), ] @@ -299,7 +302,8 @@ def my_task(wf_params, a, b): [], [ _literals.Binding( - "a", interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), + "a", + interface.BindingData.from_python_std(_types.Types.Integer.to_flyte_literal_type(), 3), ) ], None, @@ -349,7 +353,9 @@ def my_list_task(wf_params, a, b): wf_out = [ _local_workflow.Output( - "nested_out", [n5.outputs.b, n6.outputs.b, [n1.outputs.b, n2.outputs.b]], sdk_type=[[primitives.Integer]], + "nested_out", + [n5.outputs.b, 
n6.outputs.b, [n1.outputs.b, n2.outputs.b]], + sdk_type=[[primitives.Integer]], ), _local_workflow.Output("scalar_out", n1.outputs.b, sdk_type=primitives.Integer), ] @@ -369,6 +375,9 @@ class MyWorkflow(object): input_1 = promise.Input("input_1", primitives.Integer) input_2 = promise.Input("input_2", primitives.Integer, default=5, help="Not required.") - w = build_sdk_workflow_from_metaclass(MyWorkflow, disable_default_launch_plan=True,) + w = build_sdk_workflow_from_metaclass( + MyWorkflow, + disable_default_launch_plan=True, + ) assert w.should_create_default_launch_plan is False diff --git a/tests/flytekit/unit/common_tests/test_workflow_promote.py b/tests/flytekit/unit/common_tests/test_workflow_promote.py index bbfd5e627e..cdd1231107 100644 --- a/tests/flytekit/unit/common_tests/test_workflow_promote.py +++ b/tests/flytekit/unit/common_tests/test_workflow_promote.py @@ -39,7 +39,12 @@ def get_sample_container(): resources = _task_model.Resources(requests=[cpu_resource], limits=[cpu_resource]) return _task_model.Container( - "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {}, {}, + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {}, + {}, ) diff --git a/tests/flytekit/unit/common_tests/types/impl/test_schema.py b/tests/flytekit/unit/common_tests/types/impl/test_schema.py index 90733564d7..de5bba4a46 100644 --- a/tests/flytekit/unit/common_tests/types/impl/test_schema.py +++ b/tests/flytekit/unit/common_tests/types/impl/test_schema.py @@ -164,7 +164,9 @@ def test_fetch(value_type_pair): with _utils.AutoDeletingTempDir("test2") as local_dir: schema_obj = _schema_impl.Schema.fetch( - tmpdir.name, local_path=local_dir.get_named_tempfile("schema_test"), schema_type=schema_type, + tmpdir.name, + local_path=local_dir.get_named_tempfile("schema_test"), + schema_type=schema_type, ) with schema_obj as reader: for df in reader.iter_chunks(): @@ -283,7 +285,8 @@ def uuid4(self): ) SET LOCATION 's3://my_fixed_path/'; """ query = df.get_write_partition_to_hive_table_query( - "some_table", partitions=_collections.OrderedDict([("region", "SEA"), ("ds", "2017-01-01")]), + "some_table", + partitions=_collections.OrderedDict([("region", "SEA"), ("ds", "2017-01-01")]), ) full_query = " ".join(full_query.split()) query = " ".join(query.split()) @@ -299,7 +302,8 @@ def test_partial_column_read(): writer.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) b = _schema_impl.Schema.fetch( - a.uri, schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), + a.uri, + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), ) with b as reader: df = reader.read(columns=["b"]) @@ -322,7 +326,8 @@ def single_dataframe(): ) assert s is not None n = _schema_impl.Schema.fetch( - s.uri, schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), + s.uri, + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), ) with n as reader: df2 = reader.read() @@ -338,7 +343,8 @@ def list_of_dataframes(): ) assert s is not None n = _schema_impl.Schema.fetch( - s.uri, schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), + s.uri, + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), ) with n as reader: actual = [] @@ -365,7 +371,8 @@ def empty_list(): ) assert s is not None n = _schema_impl.Schema.fetch( - s.uri, 
schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), + s.uri, + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer), ("b", _primitives.Integer)]), ) with n as reader: df = reader.read() @@ -474,7 +481,8 @@ def test_extra_schema_read(): writer.write(_pd.DataFrame.from_dict({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]})) b = _schema_impl.Schema.fetch( - a.remote_prefix, schema_type=_schema_impl.SchemaType([("a", _primitives.Integer)]), + a.remote_prefix, + schema_type=_schema_impl.SchemaType([("a", _primitives.Integer)]), ) with b as reader: df = reader.read(concat=True, truncate_extra_columns=False) diff --git a/tests/flytekit/unit/common_tests/types/test_blobs.py b/tests/flytekit/unit/common_tests/types/test_blobs.py index 4fdee08b3b..3057b11c6a 100644 --- a/tests/flytekit/unit/common_tests/types/test_blobs.py +++ b/tests/flytekit/unit/common_tests/types/test_blobs.py @@ -37,7 +37,10 @@ def test_blob_promote_from_model(): scalar=_literal_models.Scalar( blob=_literal_models.Blob( _literal_models.BlobMetadata( - _core_types.BlobType(format="f", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,) + _core_types.BlobType( + format="f", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) ), "some/path", ) diff --git a/tests/flytekit/unit/common_tests/types/test_helpers.py b/tests/flytekit/unit/common_tests/types/test_helpers.py index 425feef30f..dd8b45af23 100644 --- a/tests/flytekit/unit/common_tests/types/test_helpers.py +++ b/tests/flytekit/unit/common_tests/types/test_helpers.py @@ -35,7 +35,8 @@ def test_get_sdk_value_from_literal(): assert o.to_python_std() is None o = _type_helpers.get_sdk_value_from_literal( - _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void())), sdk_type=_sdk_types.Types.Integer, + _literals.Literal(scalar=_literals.Scalar(none_type=_literals.Void())), + sdk_type=_sdk_types.Types.Integer, ) assert o.to_python_std() is None diff --git a/tests/flytekit/unit/configuration/test_temporary_configuration.py b/tests/flytekit/unit/configuration/test_temporary_configuration.py index a41f131ec8..fe51f46d06 100644 --- a/tests/flytekit/unit/configuration/test_temporary_configuration.py +++ b/tests/flytekit/unit/configuration/test_temporary_configuration.py @@ -13,7 +13,8 @@ def test_configuration_file(): def test_internal_overrides(): with _TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "configs/good.config"), {"foo": "bar"}, + _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "configs/good.config"), + {"foo": "bar"}, ): assert _os.environ.get("FLYTE_INTERNAL_FOO") == "bar" assert _os.environ.get("FLYTE_INTERNAL_FOO") is None diff --git a/tests/flytekit/unit/contrib/sensors/test_impl.py b/tests/flytekit/unit/contrib/sensors/test_impl.py index e14d4f75da..8382dae762 100644 --- a/tests/flytekit/unit/contrib/sensors/test_impl.py +++ b/tests/flytekit/unit/contrib/sensors/test_impl.py @@ -32,7 +32,9 @@ def test_HiveNamedPartitionSensor(): assert interval is None with mock.patch.object( - HMSClient, "get_partition_by_name", side_effect=_ttypes.NoSuchObjectException(), + HMSClient, + "get_partition_by_name", + side_effect=_ttypes.NoSuchObjectException(), ): success, interval = hive_named_partition_sensor._do_poll() assert not success diff --git a/tests/flytekit/unit/core/test_references.py b/tests/flytekit/unit/core/test_references.py index 0fb66c7235..56d9bd090c 100644 --- a/tests/flytekit/unit/core/test_references.py +++ 
b/tests/flytekit/unit/core/test_references.py @@ -117,7 +117,11 @@ def inner_test(ref_mock): def test_ref_plain_no_outputs(): - r1 = ReferenceEntity(TaskReference("proj", "domain", "some.name", "abc"), inputs=kwtypes(a=str, b=int), outputs={},) + r1 = ReferenceEntity( + TaskReference("proj", "domain", "some.name", "abc"), + inputs=kwtypes(a=str, b=int), + outputs={}, + ) # Reference entities should always raise an exception when not mocked out. with pytest.raises(Exception) as e: @@ -207,7 +211,13 @@ def inner_test(ref_mock): ) def test_lps(resource_type): ref_entity = get_reference_entity( - resource_type, "proj", "dom", "app.other.flyte_entity", "123", inputs=kwtypes(a=str, b=int), outputs={}, + resource_type, + "proj", + "dom", + "app.other.flyte_entity", + "123", + inputs=kwtypes(a=str, b=int), + outputs={}, ) ctx = context_manager.FlyteContext.current_context() diff --git a/tests/flytekit/unit/core/test_schedule.py b/tests/flytekit/unit/core/test_schedule.py index 8564064e05..bd76a24405 100644 --- a/tests/flytekit/unit/core/test_schedule.py +++ b/tests/flytekit/unit/core/test_schedule.py @@ -95,7 +95,8 @@ def test_cron_schedule_schedule_validation(schedule): @_pytest.mark.parametrize( - "schedule", ["foo", "* *"], + "schedule", + ["foo", "* *"], ) def test_cron_schedule_schedule_validation_invalid(schedule): with _pytest.raises(ValueError): @@ -143,7 +144,9 @@ def quadruple(a: int) -> int: return c lp = LaunchPlan.create( - "schedule_test", quadruple, schedule=FixedRate(_datetime.timedelta(hours=12), "kickoff_input"), + "schedule_test", + quadruple, + schedule=FixedRate(_datetime.timedelta(hours=12), "kickoff_input"), ) assert lp.schedule == _schedule_models.Schedule( "kickoff_input", rate=_schedule_models.Schedule.FixedRate(12, _schedule_models.Schedule.FixedRateUnit.HOUR) diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 89f9f43b99..6d81ce317d 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -260,7 +260,11 @@ def t5(a: int) -> int: os.environ["FLYTE_INTERNAL_IMAGE"] = "docker.io/default:version" set_flyte_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/images.config")) rs = context_manager.SerializationSettings( - project="project", domain="domain", version="version", env=None, image_config=get_image_config(), + project="project", + domain="domain", + version="version", + env=None, + image_config=get_image_config(), ) t1_ser = get_serializable(OrderedDict(), rs, t1) assert t1_ser.container.image == "docker.io/xyz:version" diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 4dcdaec4b7..271cc3cd62 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -80,7 +80,10 @@ def my_task() -> str: def test_engine_file_output(): - basic_blob_type = _core_types.BlobType(format="", dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,) + basic_blob_type = _core_types.BlobType( + format="", + dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE, + ) fs = FileAccessProvider(local_sandbox_dir="/tmp/flytetesting") with context_manager.FlyteContext.current_context().new_file_access_context(file_access_provider=fs) as ctx: diff --git a/tests/flytekit/unit/engines/flyte/test_engine.py b/tests/flytekit/unit/engines/flyte/test_engine.py index 452f9beeba..ba7d478cd0 100644 --- 
a/tests/flytekit/unit/engines/flyte/test_engine.py +++ b/tests/flytekit/unit/engines/flyte/test_engine.py @@ -29,7 +29,10 @@ @pytest.fixture(scope="function", autouse=True) def temp_config(): with TemporaryConfiguration( - os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../../common/configs/local.config",), + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "../../../common/configs/local.config", + ), internal_overrides={ "image": "myflyteimage:{}".format(os.environ.get("IMAGE_VERSION", "sha")), "project": "myflyteproject", @@ -71,7 +74,10 @@ def test_task_system_failure(): engine.FlyteTask(m).execute(None, {"output_prefix": tmp.name}) doc = errors.ErrorDocument.from_flyte_idl( - utils.load_proto_from_file(errors_pb2.ErrorDocument, os.path.join(tmp.name, constants.ERROR_FILE_NAME),) + utils.load_proto_from_file( + errors_pb2.ErrorDocument, + os.path.join(tmp.name, constants.ERROR_FILE_NAME), + ) ) assert doc.error.code == "SYSTEM:Unknown" assert doc.error.kind == errors.ContainerError.Kind.RECOVERABLE @@ -87,7 +93,10 @@ def test_task_user_failure(): engine.FlyteTask(m).execute(None, {"output_prefix": tmp.name}) doc = errors.ErrorDocument.from_flyte_idl( - utils.load_proto_from_file(errors_pb2.ErrorDocument, os.path.join(tmp.name, constants.ERROR_FILE_NAME),) + utils.load_proto_from_file( + errors_pb2.ErrorDocument, + os.path.join(tmp.name, constants.ERROR_FILE_NAME), + ) ) assert doc.error.code == "USER:Unknown" assert doc.error.kind == errors.ContainerError.Kind.NON_RECOVERABLE @@ -112,7 +121,13 @@ def test_execution_notification_overrides(mock_client_factory): "xd", "xn", _execution_models.ExecutionSpec( - identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version",), + identifier.Identifier( + identifier.ResourceType.LAUNCH_PLAN, + "project", + "domain", + "name", + "version", + ), _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), disable_all=True, ), @@ -140,7 +155,13 @@ def test_execution_notification_soft_overrides(mock_client_factory): "xd", "xn", _execution_models.ExecutionSpec( - identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version",), + identifier.Identifier( + identifier.ResourceType.LAUNCH_PLAN, + "project", + "domain", + "name", + "version", + ), _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), notifications=_execution_models.NotificationList([notification]), ), @@ -161,7 +182,12 @@ def test_execution_label_overrides(mock_client_factory): labels = _common_models.Labels({"my": "label"}) engine.FlyteLaunchPlan(m).execute( - "xp", "xd", "xn", literals.LiteralMap({}), notification_overrides=[], label_overrides=labels, + "xp", + "xd", + "xn", + literals.LiteralMap({}), + notification_overrides=[], + label_overrides=labels, ) mock_client.create_execution.assert_called_once_with( @@ -169,7 +195,13 @@ def test_execution_label_overrides(mock_client_factory): "xd", "xn", _execution_models.ExecutionSpec( - identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version",), + identifier.Identifier( + identifier.ResourceType.LAUNCH_PLAN, + "project", + "domain", + "name", + "version", + ), _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), disable_all=True, labels=labels, @@ -191,7 +223,12 @@ def test_execution_annotation_overrides(mock_client_factory): annotations = 
_common_models.Annotations({"my": "annotation"}) engine.FlyteLaunchPlan(m).launch( - "xp", "xd", "xn", literals.LiteralMap({}), notification_overrides=[], annotation_overrides=annotations, + "xp", + "xd", + "xn", + literals.LiteralMap({}), + notification_overrides=[], + annotation_overrides=annotations, ) mock_client.create_execution.assert_called_once_with( @@ -199,7 +236,13 @@ def test_execution_annotation_overrides(mock_client_factory): "xd", "xn", _execution_models.ExecutionSpec( - identifier.Identifier(identifier.ResourceType.LAUNCH_PLAN, "project", "domain", "name", "version",), + identifier.Identifier( + identifier.ResourceType.LAUNCH_PLAN, + "project", + "domain", + "name", + "version", + ), _execution_models.ExecutionMetadata(_execution_models.ExecutionMetadata.ExecutionMode.MANUAL, "sdk", 0), disable_all=True, annotations=annotations, @@ -254,12 +297,23 @@ def test_fetch_active_launch_plan(mock_client_factory): def test_get_full_execution_inputs(mock_client_factory): mock_client = MagicMock() mock_client.get_execution_data = MagicMock( - return_value=_execution_models.WorkflowExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP,) + return_value=_execution_models.WorkflowExecutionGetDataResponse( + None, + None, + _INPUT_MAP, + _OUTPUT_MAP, + ) ) mock_client_factory.return_value = mock_client m = MagicMock() - type(m).id = PropertyMock(return_value=identifier.WorkflowExecutionIdentifier("project", "domain", "name",)) + type(m).id = PropertyMock( + return_value=identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ) + ) inputs = engine.FlyteWorkflowExecution(m).get_inputs() assert len(inputs.literals) == 1 @@ -280,7 +334,13 @@ def test_get_execution_inputs(mock_client_factory, execution_data_locations): mock_client_factory.return_value = mock_client m = MagicMock() - type(m).id = PropertyMock(return_value=identifier.WorkflowExecutionIdentifier("project", "domain", "name",)) + type(m).id = PropertyMock( + return_value=identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ) + ) inputs = engine.FlyteWorkflowExecution(m).get_inputs() assert len(inputs.literals) == 1 @@ -299,7 +359,13 @@ def test_get_full_execution_outputs(mock_client_factory): mock_client_factory.return_value = mock_client m = MagicMock() - type(m).id = PropertyMock(return_value=identifier.WorkflowExecutionIdentifier("project", "domain", "name",)) + type(m).id = PropertyMock( + return_value=identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ) + ) outputs = engine.FlyteWorkflowExecution(m).get_outputs() assert len(outputs.literals) == 1 @@ -320,7 +386,13 @@ def test_get_execution_outputs(mock_client_factory, execution_data_locations): mock_client_factory.return_value = mock_client m = MagicMock() - type(m).id = PropertyMock(return_value=identifier.WorkflowExecutionIdentifier("project", "domain", "name",)) + type(m).id = PropertyMock( + return_value=identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ) + ) inputs = engine.FlyteWorkflowExecution(m).get_outputs() assert len(inputs.literals) == 1 @@ -334,14 +406,24 @@ def test_get_execution_outputs(mock_client_factory, execution_data_locations): def test_get_full_node_execution_inputs(mock_client_factory): mock_client = MagicMock() mock_client.get_node_execution_data = MagicMock( - return_value=_execution_models.NodeExecutionGetDataResponse(None, None, _INPUT_MAP, _OUTPUT_MAP,) + return_value=_execution_models.NodeExecutionGetDataResponse( + None, + None, + _INPUT_MAP, + 
_OUTPUT_MAP, + ) ) mock_client_factory.return_value = mock_client m = MagicMock() type(m).id = PropertyMock( return_value=identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -350,7 +432,12 @@ def test_get_full_node_execution_inputs(mock_client_factory): assert inputs.literals["a"].scalar.primitive.integer == 1 mock_client.get_node_execution_data.assert_called_once_with( identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -368,7 +455,12 @@ def test_get_node_execution_inputs(mock_client_factory, execution_data_locations m = MagicMock() type(m).id = PropertyMock( return_value=identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -377,7 +469,12 @@ def test_get_node_execution_inputs(mock_client_factory, execution_data_locations assert inputs.literals["a"].scalar.primitive.integer == 1 mock_client.get_node_execution_data.assert_called_once_with( identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -393,7 +490,12 @@ def test_get_full_node_execution_outputs(mock_client_factory): m = MagicMock() type(m).id = PropertyMock( return_value=identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -402,7 +504,12 @@ def test_get_full_node_execution_outputs(mock_client_factory): assert outputs.literals["b"].scalar.primitive.integer == 2 mock_client.get_node_execution_data.assert_called_once_with( identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -420,7 +527,12 @@ def test_get_node_execution_outputs(mock_client_factory, execution_data_location m = MagicMock() type(m).id = PropertyMock( return_value=identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -429,7 +541,12 @@ def test_get_node_execution_outputs(mock_client_factory, execution_data_location assert inputs.literals["b"].scalar.primitive.integer == 2 mock_client.get_node_execution_data.assert_called_once_with( identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ) ) @@ -445,9 +562,20 @@ def test_get_full_task_execution_inputs(mock_client_factory): m = MagicMock() type(m).id = PropertyMock( return_value=identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), 
identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -458,9 +586,20 @@ def test_get_full_task_execution_inputs(mock_client_factory): assert inputs.literals["a"].scalar.primitive.integer == 1 mock_client.get_task_execution_data.assert_called_once_with( identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -480,9 +619,20 @@ def test_get_task_execution_inputs(mock_client_factory, execution_data_locations m = MagicMock() type(m).id = PropertyMock( return_value=identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -493,9 +643,20 @@ def test_get_task_execution_inputs(mock_client_factory, execution_data_locations assert inputs.literals["a"].scalar.primitive.integer == 1 mock_client.get_task_execution_data.assert_called_once_with( identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -513,9 +674,20 @@ def test_get_full_task_execution_outputs(mock_client_factory): m = MagicMock() type(m).id = PropertyMock( return_value=identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -526,9 +698,20 @@ def test_get_full_task_execution_outputs(mock_client_factory): assert outputs.literals["b"].scalar.primitive.integer == 2 mock_client.get_task_execution_data.assert_called_once_with( identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -548,9 +731,20 @@ def 
test_get_task_execution_outputs(mock_client_factory, execution_data_location m = MagicMock() type(m).id = PropertyMock( return_value=identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -561,9 +755,20 @@ def test_get_task_execution_outputs(mock_client_factory, execution_data_location assert inputs.literals["b"].scalar.primitive.integer == 2 mock_client.get_task_execution_data.assert_called_once_with( identifier.TaskExecutionIdentifier( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "task-name", "version",), + identifier.Identifier( + identifier.ResourceType.TASK, + "project", + "domain", + "task-name", + "version", + ), identifier.NodeExecutionIdentifier( - "node-a", identifier.WorkflowExecutionIdentifier("project", "domain", "name",), + "node-a", + identifier.WorkflowExecutionIdentifier( + "project", + "domain", + "name", + ), ), 0, ) @@ -573,7 +778,12 @@ def test_get_task_execution_outputs(mock_client_factory, execution_data_location @pytest.mark.parametrize( "tasks", [ - [_task_models.Task(identifier.Identifier(identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"), MagicMock(),)], + [ + _task_models.Task( + identifier.Identifier(identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"), + MagicMock(), + ) + ], [], ], ) diff --git a/tests/flytekit/unit/extras/sqlite3/test_task.py b/tests/flytekit/unit/extras/sqlite3/test_task.py index 9b8e360f69..45dd18e9ae 100644 --- a/tests/flytekit/unit/extras/sqlite3/test_task.py +++ b/tests/flytekit/unit/extras/sqlite3/test_task.py @@ -10,7 +10,12 @@ # This task belongs to test_task_static but is intentionally here to help test tracking tk = SQLite3Task( - "test", query_template="select * from tracks", task_config=SQLite3Config(uri=EXAMPLE_DB, compressed=True,), + "test", + query_template="select * from tracks", + task_config=SQLite3Config( + uri=EXAMPLE_DB, + compressed=True, + ), ) @@ -27,7 +32,10 @@ def test_task_schema(): query_template="select TrackId, Name from tracks limit {{.inputs.limit}}", inputs=kwtypes(limit=int), output_schema_type=FlyteSchema[kwtypes(TrackId=int, Name=str)], - task_config=SQLite3Config(uri=EXAMPLE_DB, compressed=True,), + task_config=SQLite3Config( + uri=EXAMPLE_DB, + compressed=True, + ), ) assert sql_task.output_columns is not None @@ -44,7 +52,10 @@ def my_task(df: pandas.DataFrame) -> int: "test", query_template="select * from tracks limit {{.inputs.limit}}", inputs=kwtypes(limit=int), - task_config=SQLite3Config(uri=EXAMPLE_DB, compressed=True,), + task_config=SQLite3Config( + uri=EXAMPLE_DB, + compressed=True, + ), ) @workflow diff --git a/tests/flytekit/unit/models/core/test_identifier.py b/tests/flytekit/unit/models/core/test_identifier.py index bf00aed216..7ca65daf1a 100644 --- a/tests/flytekit/unit/models/core/test_identifier.py +++ b/tests/flytekit/unit/models/core/test_identifier.py @@ -35,7 +35,10 @@ def test_node_execution_identifier(): def test_task_execution_identifier(): task_id = identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version") wf_exec_id = identifier.WorkflowExecutionIdentifier("project", "domain", "name") - node_exec_id 
= identifier.NodeExecutionIdentifier("node_id", wf_exec_id,) + node_exec_id = identifier.NodeExecutionIdentifier( + "node_id", + wf_exec_id, + ) obj = identifier.TaskExecutionIdentifier(task_id, node_exec_id, 3) assert obj.retry_attempt == 3 assert obj.task_id == task_id diff --git a/tests/flytekit/unit/models/core/test_types.py b/tests/flytekit/unit/models/core/test_types.py index 744e1f90d5..e7e98cf166 100644 --- a/tests/flytekit/unit/models/core/test_types.py +++ b/tests/flytekit/unit/models/core/test_types.py @@ -9,7 +9,10 @@ def test_blob_dimensionality(): def test_blob_type(): - o = _types.BlobType(format="csv", dimensionality=_types.BlobType.BlobDimensionality.SINGLE,) + o = _types.BlobType( + format="csv", + dimensionality=_types.BlobType.BlobDimensionality.SINGLE, + ) assert o.format == "csv" assert o.dimensionality == _types.BlobType.BlobDimensionality.SINGLE diff --git a/tests/flytekit/unit/models/core/test_workflow.py b/tests/flytekit/unit/models/core/test_workflow.py index b64b87eb0b..2cc2e51420 100644 --- a/tests/flytekit/unit/models/core/test_workflow.py +++ b/tests/flytekit/unit/models/core/test_workflow.py @@ -41,7 +41,12 @@ def test_workflow_template(): {"b": _interface.Variable(int_type, "description2"), "c": _interface.Variable(int_type, "description3")}, ) wf_node = _workflow.Node( - id="some:node:id", metadata=nm, inputs=[], upstream_node_ids=[], output_aliases=[], task_node=task, + id="some:node:id", + metadata=nm, + inputs=[], + upstream_node_ids=[], + output_aliases=[], + task_node=task, ) obj = _workflow.WorkflowTemplate( id=_generic_id, @@ -111,7 +116,12 @@ def test_node_task_with_no_inputs(): task = _workflow.TaskNode(reference_id=_generic_id) obj = _workflow.Node( - id="some:node:id", metadata=nm, inputs=[], upstream_node_ids=[], output_aliases=[], task_node=task, + id="some:node:id", + metadata=nm, + inputs=[], + upstream_node_ids=[], + output_aliases=[], + task_node=task, ) assert obj.target == task assert obj.id == "some:node:id" diff --git a/tests/flytekit/unit/models/sagemaker/test_hpo_job.py b/tests/flytekit/unit/models/sagemaker/test_hpo_job.py index eab5f209b4..4b38672300 100644 --- a/tests/flytekit/unit/models/sagemaker/test_hpo_job.py +++ b/tests/flytekit/unit/models/sagemaker/test_hpo_job.py @@ -38,7 +38,10 @@ def test_hyperparameter_tuning_job(): input_mode=training_job.InputMode.FILE, input_content_type=training_job.InputContentType.TEXT_CSV, ) - tj = training_job.TrainingJob(training_job_resource_config=rc, algorithm_specification=alg,) + tj = training_job.TrainingJob( + training_job_resource_config=rc, + algorithm_specification=alg, + ) hpo = hpo_job.HyperparameterTuningJob(max_number_of_training_jobs=10, max_parallel_training_jobs=5, training_job=tj) hpo2 = hpo_job.HyperparameterTuningJob.from_flyte_idl(hpo.to_flyte_idl()) diff --git a/tests/flytekit/unit/models/sagemaker/test_training_job.py b/tests/flytekit/unit/models/sagemaker/test_training_job.py index cc5e5f4d3d..271669b16c 100644 --- a/tests/flytekit/unit/models/sagemaker/test_training_job.py +++ b/tests/flytekit/unit/models/sagemaker/test_training_job.py @@ -70,7 +70,10 @@ def test_training_job(): input_mode=training_job.InputMode.FILE, input_content_type=training_job.InputContentType.TEXT_CSV, ) - tj = training_job.TrainingJob(training_job_resource_config=rc, algorithm_specification=alg,) + tj = training_job.TrainingJob( + training_job_resource_config=rc, + algorithm_specification=alg, + ) tj2 = training_job.TrainingJob.from_flyte_idl(tj.to_flyte_idl()) # checking tj == tj2 would return 
false because we don't have the __eq__ magic method defined diff --git a/tests/flytekit/unit/models/test_dynamic_job.py b/tests/flytekit/unit/models/test_dynamic_job.py index 0a9dff117f..1aff800abd 100644 --- a/tests/flytekit/unit/models/test_dynamic_job.py +++ b/tests/flytekit/unit/models/test_dynamic_job.py @@ -20,11 +20,18 @@ interfaces, _array_job.ArrayJob(2, 2, 2).to_dict(), container=_task.Container( - "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {"a": "b"}, + {"d": "e"}, ), ) for task_metadata, interfaces, resources in product( - parameterizers.LIST_OF_TASK_METADATA, parameterizers.LIST_OF_INTERFACES, parameterizers.LIST_OF_RESOURCES, + parameterizers.LIST_OF_TASK_METADATA, + parameterizers.LIST_OF_INTERFACES, + parameterizers.LIST_OF_RESOURCES, ) ] @@ -34,7 +41,12 @@ def test_future_task_document(task): rs = _literals.RetryStrategy(0) nm = _workflow.NodeMetadata("node-name", _timedelta(minutes=10), rs) n = _workflow.Node( - id="id", metadata=nm, inputs=[], upstream_node_ids=[], output_aliases=[], task_node=_workflow.TaskNode(task.id), + id="id", + metadata=nm, + inputs=[], + upstream_node_ids=[], + output_aliases=[], + task_node=_workflow.TaskNode(task.id), ) n.to_flyte_idl() doc = _dynamic_job.DynamicJobSpec( diff --git a/tests/flytekit/unit/models/test_launch_plan.py b/tests/flytekit/unit/models/test_launch_plan.py index 788c12119b..8c2b7db8a7 100644 --- a/tests/flytekit/unit/models/test_launch_plan.py +++ b/tests/flytekit/unit/models/test_launch_plan.py @@ -26,7 +26,9 @@ def test_lp_closure(): parameter_map.to_flyte_idl() variable_map = interface.VariableMap({"vvv": v}) obj = launch_plan.LaunchPlanClosure( - state=launch_plan.LaunchPlanState.ACTIVE, expected_inputs=parameter_map, expected_outputs=variable_map, + state=launch_plan.LaunchPlanState.ACTIVE, + expected_inputs=parameter_map, + expected_outputs=variable_map, ) assert obj.expected_inputs == parameter_map assert obj.expected_outputs == variable_map diff --git a/tests/flytekit/unit/models/test_schedule.py b/tests/flytekit/unit/models/test_schedule.py index 8bade49fcb..b7fad79124 100644 --- a/tests/flytekit/unit/models/test_schedule.py +++ b/tests/flytekit/unit/models/test_schedule.py @@ -37,7 +37,8 @@ def test_schedule_fixed_rate(): @_pytest.mark.parametrize( - "offset", [None, "P1D"], + "offset", + [None, "P1D"], ) def test_schedule_cron_schedule(offset): cs = _schedule.Schedule.CronSchedule("days", offset) diff --git a/tests/flytekit/unit/models/test_tasks.py b/tests/flytekit/unit/models/test_tasks.py index 0665f34c9f..0abd3a0402 100644 --- a/tests/flytekit/unit/models/test_tasks.py +++ b/tests/flytekit/unit/models/test_tasks.py @@ -95,7 +95,12 @@ def test_task_template(in_tuple): interfaces, {"a": 1, "b": {"c": 2, "d": 3}}, container=task.Container( - "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {"a": "b"}, + {"d": "e"}, ), config={"a": "b"}, ) @@ -143,7 +148,8 @@ def test_task_template_security_context(sec_ctx): @pytest.mark.parametrize("task_closure", parameterizers.LIST_OF_TASK_CLOSURES) def test_task(task_closure): obj = task.Task( - identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"), task_closure, + identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", 
"version"), + task_closure, ) assert obj.id.project == "project" assert obj.id.domain == "domain" @@ -156,7 +162,12 @@ def test_task(task_closure): @pytest.mark.parametrize("resources", parameterizers.LIST_OF_RESOURCES) def test_container(resources): obj = task.Container( - "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {"a": "b"}, {"d": "e"}, + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {"a": "b"}, + {"d": "e"}, ) obj.image == "my_image" obj.command == ["this", "is", "a", "cmd"] @@ -182,7 +193,10 @@ def test_sidecar_task(): def test_dataloadingconfig(): dlc = task.DataLoadingConfig( - "s3://input/path", "s3://output/path", True, task.DataLoadingConfig.LITERALMAP_FORMAT_YAML, + "s3://input/path", + "s3://output/path", + True, + task.DataLoadingConfig.LITERALMAP_FORMAT_YAML, ) dlc2 = task.DataLoadingConfig.from_flyte_idl(dlc.to_flyte_idl()) assert dlc2 == dlc diff --git a/tests/flytekit/unit/models/test_workflow_closure.py b/tests/flytekit/unit/models/test_workflow_closure.py index 3368387a7f..3e19a80657 100644 --- a/tests/flytekit/unit/models/test_workflow_closure.py +++ b/tests/flytekit/unit/models/test_workflow_closure.py @@ -17,7 +17,8 @@ def test_workflow_closure(): ) b0 = _literals.Binding( - "a", _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5))), + "a", + _literals.BindingData(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=5))), ) b1 = _literals.Binding("b", _literals.BindingData(promise=_types.OutputReference("my_node", "b"))) b2 = _literals.Binding("c", _literals.BindingData(promise=_types.OutputReference("my_node", "c"))) @@ -46,13 +47,23 @@ def test_workflow_closure(): typed_interface, {"a": 1, "b": {"c": 2, "d": 3}}, container=_task.Container( - "my_image", ["this", "is", "a", "cmd"], ["this", "is", "an", "arg"], resources, {}, {}, + "my_image", + ["this", "is", "a", "cmd"], + ["this", "is", "an", "arg"], + resources, + {}, + {}, ), ) task_node = _workflow.TaskNode(task.id) node = _workflow.Node( - id="my_node", metadata=node_metadata, inputs=[b0], upstream_node_ids=[], output_aliases=[], task_node=task_node, + id="my_node", + metadata=node_metadata, + inputs=[b0], + upstream_node_ids=[], + output_aliases=[], + task_node=task_node, ) template = _workflow.WorkflowTemplate( diff --git a/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py b/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py index 532a92d5a8..2de75de96c 100644 --- a/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_dynamic_sidecar_tasks.py @@ -12,9 +12,18 @@ def get_pod_spec(): a_container = generated_pb2.Container(name="main") a_container.command.extend(["foo", "bar"]) - a_container.volumeMounts.extend([generated_pb2.VolumeMount(name="scratch", mountPath="/scratch",)]) + a_container.volumeMounts.extend( + [ + generated_pb2.VolumeMount( + name="scratch", + mountPath="/scratch", + ) + ] + ) - pod_spec = generated_pb2.PodSpec(restartPolicy="Never",) + pod_spec = generated_pb2.PodSpec( + restartPolicy="Never", + ) pod_spec.containers.extend([a_container, generated_pb2.Container(name="sidecar")]) return pod_spec diff --git a/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py b/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py index 63542e3331..e81edec5de 100644 --- a/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_hive_tasks.py @@ -36,7 +36,9 @@ def 
sample_hive_task_no_queries(wf_params): @qubole_hive_task( - cache_version="1", cluster_label=_six.text_type("cluster_label"), tags=[], + cache_version="1", + cluster_label=_six.text_type("cluster_label"), + tags=[], ) def sample_qubole_hive_task_no_input(wf_params): return _six.text_type("select 5") @@ -44,7 +46,9 @@ def sample_qubole_hive_task_no_input(wf_params): @inputs(in1=Types.Integer) @qubole_hive_task( - cache_version="1", cluster_label=_six.text_type("cluster_label"), tags=[_six.text_type("tag1")], + cache_version="1", + cluster_label=_six.text_type("cluster_label"), + tags=[_six.text_type("tag1")], ) def sample_qubole_hive_task(wf_params, in1): return _six.text_type("select ") + _six.text_type(in1) diff --git a/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py b/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py index f41c547ad5..19b4b8e2a1 100644 --- a/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_sagemaker_tasks.py @@ -83,7 +83,9 @@ def test_builtin_algorithm_training_job_task(): builtin_algorithm_training_job_task = SdkBuiltinAlgorithmTrainingJobTask( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=25, + instance_type="ml.m4.xlarge", + instance_count=1, + volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, @@ -100,7 +102,10 @@ def test_builtin_algorithm_training_job_task(): assert isinstance(builtin_algorithm_training_job_task, _sdk_task.SdkTask) assert builtin_algorithm_training_job_task.interface.inputs["train"].description == "" assert builtin_algorithm_training_job_task.interface.inputs["train"].type == _idl_types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ) assert ( builtin_algorithm_training_job_task.interface.inputs["train"].type @@ -112,7 +117,10 @@ def test_builtin_algorithm_training_job_task(): == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() ) assert builtin_algorithm_training_job_task.interface.inputs["train"].type == _idl_types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ) assert builtin_algorithm_training_job_task.interface.inputs["static_hyperparameters"].description == "" assert ( @@ -133,13 +141,16 @@ def test_builtin_algorithm_training_job_task(): assert "metricDefinitions" not in builtin_algorithm_training_job_task.custom["algorithmSpecification"].keys() ParseDict( - builtin_algorithm_training_job_task.custom["trainingJobResourceConfig"], _pb2_TrainingJobResourceConfig(), + builtin_algorithm_training_job_task.custom["trainingJobResourceConfig"], + _pb2_TrainingJobResourceConfig(), ) # fails the test if it cannot be parsed builtin_algorithm_training_job_task2 = SdkBuiltinAlgorithmTrainingJobTask( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=25, + instance_type="ml.m4.xlarge", + instance_count=1, + volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, @@ -175,7 +186,10 @@ def test_simple_hpo_job_task(): == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() ) assert 
simple_xgboost_hpo_job_task.interface.inputs["train"].type == _idl_types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ) assert simple_xgboost_hpo_job_task.interface.inputs["validation"].description == "" assert ( @@ -183,7 +197,10 @@ def test_simple_hpo_job_task(): == _sdk_types.Types.MultiPartCSV.to_flyte_literal_type() ) assert simple_xgboost_hpo_job_task.interface.inputs["validation"].type == _idl_types.LiteralType( - blob=_core_types.BlobType(format="csv", dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART,) + blob=_core_types.BlobType( + format="csv", + dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART, + ) ) assert simple_xgboost_hpo_job_task.interface.inputs["static_hyperparameters"].description == "" assert ( @@ -227,7 +244,9 @@ def test_custom_training_job(): @outputs(model=Types.Blob) @custom_training_job_task( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=25, + instance_type="ml.m4.xlarge", + instance_count=1, + volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, @@ -252,7 +271,8 @@ class MyWf(object): default=_HyperparameterTuningJobConfig( tuning_strategy=HyperparameterTuningStrategy.BAYESIAN, tuning_objective=HyperparameterTuningObjective( - objective_type=HyperparameterTuningObjectiveType.MINIMIZE, metric_name="validation:error", + objective_type=HyperparameterTuningObjectiveType.MINIMIZE, + metric_name="validation:error", ), training_job_early_stopping_type=TrainingJobEarlyStoppingType.AUTO, ), @@ -309,7 +329,9 @@ def setUp(self): @outputs(model=Types.Blob) @custom_training_job_task( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", instance_count=1, volume_size_in_gb=25, + instance_type="ml.m4.xlarge", + instance_count=1, + volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, @@ -360,7 +382,9 @@ def setUp(self): @outputs(model=Types.Blob) @custom_training_job_task( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", instance_count=2, volume_size_in_gb=25, + instance_type="ml.m4.xlarge", + instance_count=2, + volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, @@ -496,7 +520,9 @@ def test_if_wf_param_has_dist_context(self): @outputs(model=Types.Blob) @custom_training_job_task( training_job_resource_config=TrainingJobResourceConfig( - instance_type="ml.m4.xlarge", instance_count=2, volume_size_in_gb=25, + instance_type="ml.m4.xlarge", + instance_count=2, + volume_size_in_gb=25, ), algorithm_specification=AlgorithmSpecification( input_mode=InputMode.FILE, diff --git a/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py b/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py index a0fbf1f8a4..0ca8011beb 100644 --- a/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_sidecar_tasks.py @@ -10,11 +10,22 @@ def get_pod_spec(): - a_container = generated_pb2.Container(name="a container",) + a_container = generated_pb2.Container( + name="a container", + ) a_container.command.extend(["fee", "fi", "fo", "fum"]) - a_container.volumeMounts.extend([generated_pb2.VolumeMount(name="volume mount", mountPath="some/where",)]) + 
a_container.volumeMounts.extend( + [ + generated_pb2.VolumeMount( + name="volume mount", + mountPath="some/where", + ) + ] + ) - pod_spec = generated_pb2.PodSpec(restartPolicy="OnFailure",) + pod_spec = generated_pb2.PodSpec( + restartPolicy="OnFailure", + ) pod_spec.containers.extend([a_container, generated_pb2.Container(name="another container")]) return pod_spec diff --git a/tests/flytekit/unit/sdk/tasks/test_tasks.py b/tests/flytekit/unit/sdk/tasks/test_tasks.py index a21ffaed04..33e0287199 100644 --- a/tests/flytekit/unit/sdk/tasks/test_tasks.py +++ b/tests/flytekit/unit/sdk/tasks/test_tasks.py @@ -43,7 +43,10 @@ def test_default_python_task(): def test_default_resources(): with _configuration.TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../configuration/configs/good.config",) + _os.path.join( + _os.path.dirname(_os.path.realpath(__file__)), + "../../configuration/configs/good.config", + ) ): @inputs(in1=Types.Integer) @@ -69,7 +72,10 @@ def default_task2(wf_params, in1, out1): def test_overriden_resources(): with _configuration.TemporaryConfiguration( - _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), "../../configuration/configs/good.config",) + _os.path.join( + _os.path.dirname(_os.path.realpath(__file__)), + "../../configuration/configs/good.config", + ) ): @inputs(in1=Types.Integer) diff --git a/tests/flytekit/unit/test_plugins.py b/tests/flytekit/unit/test_plugins.py index 1a5fb33ccf..c6680e34da 100644 --- a/tests/flytekit/unit/test_plugins.py +++ b/tests/flytekit/unit/test_plugins.py @@ -26,7 +26,10 @@ def test_schema_plugin(): @pytest.mark.run(order=2) def test_sidecar_plugin(): assert isinstance(plugins.k8s.io.api.core.v1.generated_pb2, lazy_loader._LazyLoadModule) - assert isinstance(plugins.k8s.io.apimachinery.pkg.api.resource.generated_pb2, lazy_loader._LazyLoadModule,) + assert isinstance( + plugins.k8s.io.apimachinery.pkg.api.resource.generated_pb2, + lazy_loader._LazyLoadModule, + ) import k8s.io.api.core.v1.generated_pb2 import k8s.io.apimachinery.pkg.api.resource.generated_pb2 diff --git a/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py b/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py index b37ce8981a..8157739456 100644 --- a/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py +++ b/tests/flytekit/unit/type_engines/default/test_flyte_type_engine.py @@ -55,7 +55,11 @@ def test_infer_proto_from_literal(): _literal_models.Literal( scalar=_literal_models.Scalar( binary=_literal_models.Binary( - value="", tag="{}{}".format(_proto.Protobuf.TAG_PREFIX, "flyteidl.core.errors_pb2.ContainerError",), + value="", + tag="{}{}".format( + _proto.Protobuf.TAG_PREFIX, + "flyteidl.core.errors_pb2.ContainerError", + ), ) ) )