diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py
index c3c0833295..25c3b75b9b 100644
--- a/flytekit/clients/raw.py
+++ b/flytekit/clients/raw.py
@@ -125,33 +125,35 @@ def handler(*args, **kwargs):
             """
             max_retries = 3
             max_wait_time = 1000
-            try:
-                for i in range(max_retries):
-                    try:
-                        return fn(*args, **kwargs)
-                    except _RpcError as e:
-                        if e.code() == _GrpcStatusCode.UNAUTHENTICATED:
-                            # Always retry auth errors.
-                            if i == (max_retries - 1):
-                                # Exit the loop and wrap the authentication error.
-                                raise _user_exceptions.FlyteAuthenticationException(_six.text_type(e))
-                            cli_logger.error(f"Unauthenticated RPC error {e}, refreshing credentials and retrying\n")
-                            refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get())
-                            refresh_handler_fn(args[0])
+
+            for i in range(max_retries):
+                try:
+                    return fn(*args, **kwargs)
+                except _RpcError as e:
+                    if e.code() == _GrpcStatusCode.UNAUTHENTICATED:
+                        # Always retry auth errors.
+                        if i == (max_retries - 1):
+                            # Exit the loop and wrap the authentication error.
+                            raise _user_exceptions.FlyteAuthenticationException(_six.text_type(e))
+                        cli_logger.error(f"Unauthenticated RPC error {e}, refreshing credentials and retrying\n")
+                        refresh_handler_fn = _get_refresh_handler(_creds_config.AUTH_MODE.get())
+                        refresh_handler_fn(args[0])
+                    # There are two cases in which we must raise immediately instead of retrying:
+                    # 1. The entity already exists when we register it.
+                    # 2. The entity is not found when we fetch it.
+                    elif e.code() == _GrpcStatusCode.ALREADY_EXISTS:
+                        raise _user_exceptions.FlyteEntityAlreadyExistsException(e)
+                    elif e.code() == _GrpcStatusCode.NOT_FOUND:
+                        raise _user_exceptions.FlyteEntityNotExistException(e)
+                    else:
+                        # No more retries if retry=False or max_retries reached.
+                        if (retry is False) or i == (max_retries - 1):
+                            raise
                         else:
-                            # No more retries if retry=False or max_retries reached.
-                            if (retry is False) or i == (max_retries - 1):
-                                raise
-                            else:
-                                # Retry: Start with 200ms wait-time and exponentially back-off up to 1 second.
-                                wait_time = min(200 * (2 ** i), max_wait_time)
-                                cli_logger.error(f"Non-auth RPC error {e}, sleeping {wait_time}ms and retrying")
-                                time.sleep(wait_time / 1000)
-            except _RpcError as e:
-                if e.code() == _GrpcStatusCode.ALREADY_EXISTS:
-                    raise _user_exceptions.FlyteEntityAlreadyExistsException(_six.text_type(e))
-                else:
-                    raise
+                            # Retry: Start with 200ms wait-time and exponentially back-off up to 1 second.
+                            wait_time = min(200 * (2 ** i), max_wait_time)
+                            cli_logger.error(f"Non-auth RPC error {e}, sleeping {wait_time}ms and retrying")
+                            time.sleep(wait_time / 1000)
 
         return handler
 
diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py
index 39a7be762a..925f79c770 100644
--- a/flytekit/core/type_engine.py
+++ b/flytekit/core/type_engine.py
@@ -543,7 +543,7 @@ def guess_python_type(cls, flyte_type: LiteralType) -> type:
             except ValueError:
                 logger.debug(f"Skipping transformer {transformer.name} for {flyte_type}")
 
-        # Because the dataclass transformer is handled explicity in the get_transformer code, we have to handle it
+        # Because the dataclass transformer is handled explicitly in the get_transformer code, we have to handle it
         # separately here too.
         try:
             return cls._DATACLASS_TRANSFORMER.guess_python_type(literal_type=flyte_type)
diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py
index 41a128caf9..d9fcd4a0ed 100644
--- a/flytekit/remote/remote.py
+++ b/flytekit/remote/remote.py
@@ -1,6 +1,7 @@
 """Module defining main Flyte backend entrypoint."""
 from __future__ import annotations
 
+import logging
 import os
 import time
 import typing
@@ -15,9 +16,12 @@
 
 from flytekit.clients.friendly import SynchronousFlyteClient
 from flytekit.common import utils as common_utils
+from flytekit.common.exceptions.user import FlyteEntityAlreadyExistsException, FlyteEntityNotExistException
+from flytekit.configuration import internal
 from flytekit.configuration import platform as platform_config
 from flytekit.configuration import sdk as sdk_config
 from flytekit.configuration import set_flyte_config_file
+from flytekit.core import context_manager
 from flytekit.core.interface import Interface
 from flytekit.loggers import remote_logger
 from flytekit.models import filters as filter_models
@@ -202,7 +206,6 @@ def __init__(
             raise user_exceptions.FlyteAssertion("Cannot find flyte admin url in config file.")
 
         self._client = SynchronousFlyteClient(flyte_admin_url, insecure=insecure, credentials=grpc_credentials)
-
         # read config files, env vars, host, ssl options for admin client
         self._flyte_admin_url = flyte_admin_url
         self._insecure = insecure
@@ -520,6 +523,8 @@ def _serialize(
                 domain or self.default_domain,
                 version or self.version,
                 self.image_config,
+                # https://github.com/flyteorg/flyte/issues/1359
+                env={internal.IMAGE.env_var: self.image_config.default_image.full},
             ),
             entity=entity,
         )
@@ -604,6 +609,24 @@ def _(
         )
         return self.fetch_launch_plan(**resolved_identifiers)
 
+    def _register_entity_if_not_exists(self, entity: WorkflowBase, resolved_identifiers_dict: dict):
+        """Register, best-effort, every task, launch plan and subworkflow referenced by the workflow's nodes."""
+        node_identifiers_dict = deepcopy(resolved_identifiers_dict)
+        for node in entity.nodes:
+            try:
+                node_identifiers_dict["name"] = node.flyte_entity.name
+                if isinstance(node.flyte_entity, WorkflowBase):
+                    self._register_entity_if_not_exists(node.flyte_entity, node_identifiers_dict)
+                    self.register(node.flyte_entity, **node_identifiers_dict)
+                elif isinstance(node.flyte_entity, (PythonTask, LaunchPlan)):
+                    self.register(node.flyte_entity, **node_identifiers_dict)
+                else:
+                    raise NotImplementedError(f"We don't support registering this kind of entity: {node.flyte_entity}")
+            except FlyteEntityAlreadyExistsException:
+                logging.info(f"{node.flyte_entity.name} already exists")
+            except Exception as e:
+                logging.info(f"Failed to register entity {node.flyte_entity} of {entity.name} with error {e}")
+
     ####################
     # Execute Entities #
     ####################
@@ -884,11 +907,23 @@ def _(
         """Execute an @workflow-decorated function."""
         resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version)
         resolved_identifiers_dict = asdict(resolved_identifiers)
+
         try:
             flyte_workflow: FlyteWorkflow = self.fetch_workflow(**resolved_identifiers_dict)
-        except Exception:
+        except FlyteEntityNotExistException:
+            logging.info("Try to register FlyteWorkflow because it wasn't found in Flyte Admin!")
+            self._register_entity_if_not_exists(entity, resolved_identifiers_dict)
             flyte_workflow: FlyteWorkflow = self.register(entity, **resolved_identifiers_dict)
         flyte_workflow.guessed_python_interface = entity.python_interface
+
+        ctx = context_manager.FlyteContext.current_context()
+        try:
+            self.fetch_launch_plan(**resolved_identifiers_dict)
+        except FlyteEntityNotExistException:
+            logging.info("Try to register default launch plan because it wasn't found in Flyte Admin!")
+            default_lp = LaunchPlan.get_default_launch_plan(ctx, entity)
+            self.register(default_lp, **resolved_identifiers_dict)
+
         return self.execute(
             flyte_workflow,
             inputs,
diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py
index 2b61f220c6..c00564a5f3 100644
--- a/tests/flytekit/integration/remote/test_remote.py
+++ b/tests/flytekit/integration/remote/test_remote.py
@@ -9,7 +9,7 @@
 import pytest
 
 from flytekit import kwtypes
-from flytekit.common.exceptions.user import FlyteAssertion
+from flytekit.common.exceptions.user import FlyteAssertion, FlyteEntityNotExistException
 from flytekit.core.launch_plan import LaunchPlan
 from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task
 from flytekit.remote.remote import FlyteRemote
@@ -238,8 +238,8 @@ def test_execute_python_workflow_list_of_floats(flyteclient, flyte_workflows_reg
 
     # make sure the task name is the same as the name used during registration
     my_wf._name = my_wf.name.replace("mock_flyte_repo.", "")
-
     remote = FlyteRemote.from_config(PROJECT, "development")
+
     xs: typing.List[float] = [42.24, 999.1, 0.0001]
     execution = remote.execute(my_wf, inputs={"xs": xs}, version=f"v{VERSION}", wait=True)
     assert execution.outputs["o0"] == "[42.24, 999.1, 0.0001]"
@@ -282,3 +282,29 @@ def test_execute_joblib_workflow(flyteclient, flyte_workflows_register, flyte_re
     output_obj = joblib.load(joblib_output.path)
     assert execution.outputs["o0"].extension() == "joblib"
     assert output_obj == input_obj
+
+
+def test_execute_with_default_launch_plan(flyteclient, flyte_workflows_register, flyte_remote_env):
+    from mock_flyte_repo.workflows.basic.subworkflows import parent_wf
+
+    # make sure the task name is the same as the name used during registration
+    parent_wf._name = parent_wf.name.replace("mock_flyte_repo.", "")
+
+    remote = FlyteRemote.from_config(PROJECT, "development")
+    execution = remote.execute(parent_wf, {"a": 101}, version=f"v{VERSION}", wait=True)
+    # check node execution inputs and outputs
+    assert execution.node_executions["n0"].inputs == {"a": 101}
+    assert execution.node_executions["n0"].outputs == {"t1_int_output": 103, "c": "world"}
+    assert execution.node_executions["n1"].inputs == {"a": 103}
+    assert execution.node_executions["n1"].outputs == {"o0": "world", "o1": "world"}
+
+    # check subworkflow task execution inputs and outputs
+    subworkflow_node_executions = execution.node_executions["n1"].subworkflow_node_executions
+    assert subworkflow_node_executions["n1-0-n0"].inputs == {"a": 103}
+    assert subworkflow_node_executions["n1-0-n1"].outputs == {"t1_int_output": 107, "c": "world"}
+
+
+def test_fetch_not_exist_launch_plan(flyteclient):
+    remote = FlyteRemote.from_config(PROJECT, "development")
+    with pytest.raises(FlyteEntityNotExistException):
+        remote.fetch_launch_plan(name="workflows.basic.list_float_wf.fake_wf", version=f"v{VERSION}")