From 59aa29c932712de92ecfdba3a786b30d4f796faa Mon Sep 17 00:00:00 2001 From: cosmicBboy Date: Wed, 21 Jul 2021 17:31:21 -0400 Subject: [PATCH] move version from __init__ to methods Signed-off-by: cosmicBboy --- flytekit/remote/remote.py | 339 ++++++++++++------ .../integration/remote/test_remote.py | 42 +-- tests/flytekit/unit/remote/test_remote.py | 4 +- 3 files changed, 256 insertions(+), 129 deletions(-) diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 391a93813d5..e4b08d8beb9 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -26,13 +26,12 @@ from flytekit.common.exceptions import user as user_exceptions from flytekit.common.translator import FlyteControlPlaneEntity, FlyteLocalEntity, get_serializable from flytekit.configuration import auth as auth_config -from flytekit.configuration.internal import DOMAIN, PROJECT, VERSION +from flytekit.configuration.internal import DOMAIN, PROJECT from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import FlyteContextManager, ImageConfig, SerializationSettings, get_image_config from flytekit.core.launch_plan import LaunchPlan from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import WorkflowBase -from flytekit.interfaces.data import data_proxy from flytekit.models import common as common_models from flytekit.models import launch_plan as launch_plan_models from flytekit.models import literals as literal_models @@ -101,42 +100,31 @@ class FlyteRemote(object): @staticmethod def from_environment( - project: typing.Optional[str] = None, domain: typing.Optional[str] = None, version: typing.Optional[str] = None + default_project: typing.Optional[str] = None, default_domain: typing.Optional[str] = None ) -> FlyteRemote: - project = project or PROJECT.get() - domain = domain or DOMAIN.get() - version = version or VERSION.get() - auth_role = common_models.AuthRole( - assumable_iam_role=auth_config.ASSUMABLE_IAM_ROLE.get(), - kubernetes_service_account=auth_config.KUBERNETES_SERVICE_ACCOUNT.get(), - ) - raw_output_data_prefix = auth_config.RAW_OUTPUT_DATA_PREFIX.get() - raw_output_data_config = ( - common_models.RawOutputDataConfig(raw_output_data_prefix) if raw_output_data_prefix else None - ) - - image_config = get_image_config() - return FlyteRemote( - project=project, - domain=domain, - version=version, + default_project=default_project or PROJECT.get(), + default_domain=default_domain or DOMAIN.get(), flyte_admin_url=platform_config.URL.get(), insecure=platform_config.INSECURE.get(), - auth_role=auth_role, + auth_role=common_models.AuthRole( + assumable_iam_role=auth_config.ASSUMABLE_IAM_ROLE.get(), + kubernetes_service_account=auth_config.KUBERNETES_SERVICE_ACCOUNT.get(), + ), notifications=None, labels=None, annotations=None, - image_config=image_config, - raw_output_data_config=raw_output_data_config, + image_config=get_image_config(), + raw_output_data_config=( + common_models.RawOutputDataConfig(raw_output_data_prefix) if raw_output_data_prefix else None + ), ) def __init__( self, - project: str, - domain: str, - version: str, + default_project: str, + default_domain: str, flyte_admin_url: str, insecure: bool, auth_role: typing.Optional[common_models.AuthRole] = None, @@ -148,15 +136,11 @@ def __init__( ): remote_logger.warning("This feature is still in beta. Its interface and UX is subject to change.") - # TODO: figure out what config/metadata needs to be loaded into the FlyteRemote object at initialization - self._client = SynchronousFlyteClient(flyte_admin_url, insecure=insecure) - # - read config files, env vars - # - host, ssl options for admin client - self.project = project - self.domain = domain - self.version = version + # read config files, env vars, host, ssl options for admin client + self.default_project = default_project + self.default_domain = default_domain self.image_config = image_config self.auth_role = auth_role self.notifications = notifications @@ -167,35 +151,39 @@ def __init__( # TODO: Reconsider whether we want this. Probably best to not cache. self.serialized_entity_cache = OrderedDict() - self.serialization_settings = SerializationSettings( - self.project, - self.domain, - self.version, - self.image_config, - ) - @property def client(self) -> SynchronousFlyteClient: + """Return a SynchronousFlyteClient for additional operations.""" return self._client + @property + def version(self) -> str: + """Get a randomly generated version string.""" + return uuid.uuid4().hex[:30] + str(int(time.time())) + def with_overrides( self, - project=None, - domain=None, - version=None, - auth_role=None, - notifications=None, - labels=None, - annotations=None, + default_project: str = None, + default_domain: str = None, + flyte_admin_url: str = None, + insecure: bool = None, + auth_role: typing.Optional[common_models.AuthRole] = None, + notifications: typing.Optional[typing.List[common_models.Notification]] = None, + labels: typing.Optional[common_models.Labels] = None, + annotations: typing.Optional[common_models.Annotations] = None, + image_config: typing.Optional[ImageConfig] = None, + raw_output_data_config: typing.Optional[common_models.RawOutputDataConfig] = None, ): """Create a copy of the remote object, overriding the specified attributes.""" new_remote = deepcopy(self) - if project: - new_remote.project = project - if domain: - new_remote.domain = domain - if version: - new_remote.version = version + if default_project: + new_remote.default_project = default_project + if default_domain: + new_remote.default_domain = default_domain + if flyte_admin_url: + new_remote.flyte_admin_url = flyte_admin_url + if insecure: + new_remote.insecure = insecure if auth_role: new_remote.auth_role = auth_role if notifications: @@ -204,20 +192,40 @@ def with_overrides( new_remote.labels = labels if annotations: new_remote.annotations = annotations + if image_config: + new_remote.image_config = image_config + if raw_output_data_config: + new_remote.raw_output_data_config = raw_output_data_config return new_remote - def fetch_task(self, project: str, domain: str, name: str, version: str = None) -> FlyteTask: + def fetch_task(self, project: str = None, domain: str = None, name: str = None, version: str = None) -> FlyteTask: + if name is None: + raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.") task_id = _get_entity_identifier( - self.client.list_tasks_paginated, ResourceType.TASK, project, domain, name, version + self.client.list_tasks_paginated, + ResourceType.TASK, + project or self.default_project, + domain or self.default_domain, + name, + version or self.version, ) admin_task = self.client.get_task(task_id) flyte_task = FlyteTask.promote_from_model(admin_task.closure.compiled_task.template) flyte_task._id = task_id return flyte_task - def fetch_workflow(self, project: str, domain: str, name: str, version: str = None) -> FlyteWorkflow: + def fetch_workflow( + self, project: str = None, domain: str = None, name: str = None, version: str = None + ) -> FlyteWorkflow: + if name is None: + raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.") workflow_id = _get_entity_identifier( - self.client.list_workflows_paginated, ResourceType.WORKFLOW, project, domain, name, version + self.client.list_workflows_paginated, + ResourceType.WORKFLOW, + project or self.default_project, + domain or self.default_domain, + name, + version or self.version, ) admin_workflow = self.client.get_workflow(workflow_id) compiled_wf = admin_workflow.closure.compiled_workflow @@ -229,9 +237,18 @@ def fetch_workflow(self, project: str, domain: str, name: str, version: str = No flyte_workflow._id = workflow_id return flyte_workflow - def fetch_launch_plan(self, project: str, domain: str, name: str, version: str = None) -> FlyteLaunchPlan: + def fetch_launch_plan( + self, project: str = None, domain: str = None, name: str = None, version: str = None + ) -> FlyteLaunchPlan: + if name is None: + raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.") launch_plan_id = _get_entity_identifier( - self.client.list_launch_plans_paginated, ResourceType.LAUNCH_PLAN, project, domain, name, version + self.client.list_launch_plans_paginated, + ResourceType.LAUNCH_PLAN, + project or self.default_project, + domain or self.default_domain, + name, + version or self.version, ) admin_launch_plan = self.client.get_launch_plan(launch_plan_id) flyte_launch_plan = FlyteLaunchPlan.promote_from_model(admin_launch_plan.spec) @@ -242,9 +259,19 @@ def fetch_launch_plan(self, project: str, domain: str, name: str, version: str = flyte_launch_plan._interface = workflow.interface return flyte_launch_plan - def fetch_workflow_execution(self, project: str, domain: str, name: str) -> FlyteWorkflowExecution: + def fetch_workflow_execution( + self, project: str = None, domain: str = None, name: str = None + ) -> FlyteWorkflowExecution: + if name is None: + raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.") return FlyteWorkflowExecution.promote_from_model( - self.client.get_execution(WorkflowExecutionIdentifier(project, domain, name)) + self.client.get_execution( + WorkflowExecutionIdentifier( + project or self.default_project, + domain or self.default_domain, + name, + ) + ) ) ###################### @@ -252,45 +279,86 @@ def fetch_workflow_execution(self, project: str, domain: str, name: str) -> Flyt ###################### @singledispatchmethod - def _serialize(self, entity: FlyteLocalEntity) -> FlyteControlPlaneEntity: + def _serialize( + self, + entity: FlyteLocalEntity, + project: str = None, + domain: str = None, + version: str = None, + **kwargs, + ) -> FlyteControlPlaneEntity: # TODO: Revisit cache - return get_serializable(self.serialized_entity_cache, self.serialization_settings, entity=entity) + return get_serializable( + self.serialized_entity_cache, + SerializationSettings( + project or self.default_project, + domain or self.default_domain, + version or self.version, + self.image_config, + ), + entity=entity, + ) ##################### # Register Entities # ##################### + def _resolve_identifier_kwargs( + self, + entity, + project: typing.Optional[str], + domain: typing.Optional[str], + name: typing.Optional[str], + version: typing.Optional[str], + ): + """Resolves the identifier attributes based on user input, falling back on .""" + return { + "project": project or self.default_project, + "domain": domain or self.default_domain, + "name": name or entity.name, + "version": version or self.version, + } + @singledispatchmethod - def register(self, entity): + def register(self, entity, project: str = None, domain: str = None, name: str = None, version: str = None): + """Register an entity to flyte admin.""" raise NotImplementedError(f"entity type {type(entity)} not recognized for registration") @register.register - def _(self, entity: PythonTask): + def _(self, entity: PythonTask, project: str = None, domain: str = None, name: str = None, version: str = None): + """Register an @task-decorated function or TaskTemplate task to flyte admin.""" + flyte_id_kwargs = self._resolve_identifier_kwargs(entity, project, domain, name, version) self.client.create_task( - Identifier(ResourceType.TASK, self.project, self.domain, entity.name, self.version), - task_spec=self._serialize(entity), + Identifier(ResourceType.TASK, **flyte_id_kwargs), + task_spec=self._serialize(entity, **flyte_id_kwargs), ) - return self.fetch_task(self.project, self.domain, entity.name, self.version) + return self.fetch_task(**flyte_id_kwargs) @register.register - def _(self, entity: WorkflowBase): + def _(self, entity: WorkflowBase, project: str = None, domain: str = None, name: str = None, version: str = None): + """Register an @workflow-decorated function to flyte admin.""" + flyte_id_kwargs = self._resolve_identifier_kwargs(entity, project, domain, name, version) self.client.create_workflow( - Identifier(ResourceType.WORKFLOW, self.project, self.domain, entity.name, self.version), - workflow_spec=self._serialize(entity), + Identifier(ResourceType.WORKFLOW, **flyte_id_kwargs), + workflow_spec=self._serialize(entity, **flyte_id_kwargs), ) - return self.fetch_workflow(self.project, self.domain, entity.name, self.version) + return self.fetch_workflow(**flyte_id_kwargs) @register.register - def _(self, entity: LaunchPlan): + def _(self, entity: LaunchPlan, project: str = None, domain: str = None, name: str = None, version: str = None): + """Register a LaunchPlan object to flyte admin.""" # See _get_patch_launch_plan_fn for what we need to patch. These are the elements of a launch plan # that are not set at serialization time and are filled in either by flyte-cli register files or flytectl. - serialized_lp: launch_plan_models.LaunchPlan = self._serialize(entity) - serialized_lp.spec._auth_role = common_models.AuthRole( - self.auth_role.assumable_iam_role, self.auth_role.kubernetes_service_account - ) - serialized_lp.spec._raw_output_data_config = common_models.RawOutputDataConfig( - self.raw_output_data_config.output_location_prefix - ) + flyte_id_kwargs = self._resolve_identifier_kwargs(entity, project, domain, name, version) + serialized_lp: launch_plan_models.LaunchPlan = self._serialize(entity, **flyte_id_kwargs) + if self.auth_role: + serialized_lp.spec._auth_role = common_models.AuthRole( + self.auth_role.assumable_iam_role, self.auth_role.kubernetes_service_account + ) + if self.raw_output_data_config: + serialized_lp.spec._raw_output_data_config = common_models.RawOutputDataConfig( + self.raw_output_data_config.output_location_prefix + ) # Patch in labels and annotations if self.labels: @@ -302,10 +370,10 @@ def _(self, entity: LaunchPlan): serialized_lp.spec._annotations.values[k] = v self.client.create_launch_plan( - Identifier(ResourceType.LAUNCH_PLAN, self.project, self.domain, entity.name, self.version), + Identifier(ResourceType.LAUNCH_PLAN, **flyte_id_kwargs), launch_plan_spec=serialized_lp.spec, ) - return self.fetch_launch_plan(self.project, self.domain, entity.name, self.version) + return self.fetch_launch_plan(**flyte_id_kwargs) #################### # Execute Entities # @@ -356,8 +424,16 @@ def _execute( @singledispatchmethod def execute( - self, entity, inputs: typing.Dict[str, typing.Any], execution_name=None, wait=False + self, + entity, + inputs: typing.Dict[str, typing.Any], + execution_name=None, + project: str = None, + domain: str = None, + version: str = None, + wait=False, ) -> FlyteWorkflowExecution: + """Execute a task, workflow, or launchplan.""" raise NotImplementedError(f"entity type {type(entity)} not recognized for execution") # Flyte Remote Entities @@ -366,52 +442,105 @@ def execute( @execute.register(FlyteTask) @execute.register(FlyteLaunchPlan) def _( - self, entity, inputs: typing.Dict[str, typing.Any], execution_name=None, wait=False + self, + entity, + inputs: typing.Dict[str, typing.Any], + project: str = None, + domain: str = None, + name: str = None, + version: str = None, + execution_name=None, + wait=False, ) -> FlyteWorkflowExecution: + """Execute a FlyteTask, or FlyteLaunchplan.""" return self._execute(entity.id, inputs, execution_name, wait) @execute.register def _( - self, entity: FlyteWorkflow, inputs: typing.Dict[str, typing.Any], execution_name=None, wait=False + self, + entity: FlyteWorkflow, + inputs: typing.Dict[str, typing.Any], + project: str = None, + domain: str = None, + name: str = None, + version: str = None, + execution_name=None, + wait=False, ) -> FlyteWorkflowExecution: - launch_plan = self.fetch_launch_plan(self.project, self.domain, entity.id.name, self.version) - return self.execute(launch_plan, inputs, execution_name, wait) + """Execute a FlyteWorkflow.""" + return self.execute( + self.fetch_launch_plan(entity.id.project, entity.id.domain, entity.id.name, entity.id.version), + inputs, + execution_name=execution_name, + wait=wait, + ) # Flytekit Entities # ----------------- @execute.register def _( - self, entity: PythonTask, inputs: typing.Dict[str, typing.Any], execution_name: str = None, wait=False + self, + entity: PythonTask, + inputs: typing.Dict[str, typing.Any], + project: str = None, + domain: str = None, + name: str = None, + version: str = None, + execution_name: str = None, + wait=False, ) -> FlyteWorkflowExecution: + """Execute an @task-decorated function or TaskTemplate task.""" + flyte_id_kwargs = self._resolve_identifier_kwargs(entity, project, domain, name, version) try: - flyte_task: FlyteTask = self.fetch_task(self.project, self.domain, entity.name, self.version) + flyte_task: FlyteTask = self.fetch_task(**flyte_id_kwargs) except Exception: - # TODO: fast register the task if fast=True - flyte_task: FlyteTask = self.register(entity) - return self.execute(flyte_task, inputs, execution_name, wait) + flyte_task: FlyteTask = self.register(entity, **flyte_id_kwargs) + return self.execute(flyte_task, inputs, execution_name=execution_name, wait=wait) @execute.register def _( - self, entity: WorkflowBase, inputs: typing.Dict[str, typing.Any], execution_name=None, wait=False + self, + entity: WorkflowBase, + inputs: typing.Dict[str, typing.Any], + project: str = None, + domain: str = None, + name: str = None, + version: str = None, + execution_name=None, + wait=False, ) -> FlyteWorkflowExecution: + """Execute an @workflow-decorated function.""" + flyte_id_kwargs = self._resolve_identifier_kwargs(entity, project, domain, name, version) try: - flyte_workflow: FlyteWorkflow = self.fetch_workflow(self.project, self.domain, entity.name, self.version) + flyte_workflow: FlyteWorkflow = self.fetch_workflow(**flyte_id_kwargs) except Exception: - flyte_workflow: FlyteWorkflow = self.register(entity) - return self.execute(flyte_workflow, inputs, execution_name, wait) + flyte_workflow: FlyteWorkflow = self.register(entity, **flyte_id_kwargs) + return self.execute(flyte_workflow, inputs, execution_name=execution_name, wait=wait) @execute.register def _( - self, entity: LaunchPlan, inputs: typing.Dict[str, typing.Any], execution_name=None, wait=False + self, + entity: LaunchPlan, + inputs: typing.Dict[str, typing.Any], + project: str = None, + domain: str = None, + name: str = None, + version: str = None, + execution_name=None, + wait=False, ) -> FlyteWorkflowExecution: + """Execute a LaunchPlan object.""" + flyte_id_kwargs = self._resolve_identifier_kwargs(entity, project, domain, name, version) try: - flyte_launchplan: FlyteLaunchPlan = self.fetch_launch_plan( - self.project, self.domain, entity.name, self.version - ) + flyte_launchplan: FlyteLaunchPlan = self.fetch_launch_plan(**flyte_id_kwargs) except Exception: - flyte_launchplan: FlyteLaunchPlan = self.register(entity) - return self.execute(flyte_launchplan, inputs, execution_name, wait) + flyte_launchplan: FlyteLaunchPlan = self.register(entity, **flyte_id_kwargs) + return self.execute(flyte_launchplan, inputs, execution_name=execution_name, wait=wait) + + ################################### + # Wait for Executions to Complete # + ################################### @singledispatchmethod def wait( @@ -420,6 +549,7 @@ def wait( timeout: typing.Optional[timedelta] = None, poll_interval: typing.Optional[timedelta] = None, ): + """Wait for an execution to finish.""" raise NotImplementedError(f"Execution type {type(execution)} cannot be waited upon.") @wait.register @@ -429,6 +559,7 @@ def _( timeout: typing.Optional[timedelta] = None, poll_interval: typing.Optional[timedelta] = None, ): + """Wait for a FlyteWorkflowExecution to finish.""" poll_interval = poll_interval or timedelta(seconds=30) time_to_give_up = datetime.max if timeout is None else datetime.utcnow() + timeout @@ -440,6 +571,10 @@ def _( raise user_exceptions.FlyteTimeout(f"Execution {self} did not complete before timeout.") + ######################## + # Sync Execution State # + ######################## + @singledispatchmethod def sync(self, execution): """Sync a flyte execution object with its corresponding remote state.""" diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index c43ce8de207..2b9ed353f0a 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -32,19 +32,15 @@ def test_client(flyteclient, flyte_workflows_register): def test_fetch_execute_launch_plan(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_environment(PROJECT, "development", f"v{VERSION}") - flyte_launch_plan = remote.fetch_launch_plan( - PROJECT, "development", "workflows.basic.hello_world.my_wf", f"v{VERSION}" - ) + remote = FlyteRemote.from_environment(PROJECT, "development") + flyte_launch_plan = remote.fetch_launch_plan(name="workflows.basic.hello_world.my_wf", version=f"v{VERSION}") execution = remote.execute(flyte_launch_plan, {}, wait=True) assert execution.outputs["o0"] == "hello world" def fetch_execute_launch_plan_with_args(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_environment(PROJECT, "development", f"v{VERSION}") - flyte_launch_plan = remote.fetch_launch_plan( - PROJECT, "development", "workflows.basic.basic_workflow.my_wf", f"v{VERSION}" - ) + remote = FlyteRemote.from_environment(PROJECT, "development") + flyte_launch_plan = remote.fetch_launch_plan(name="workflows.basic.basic_workflow.my_wf", version=f"v{VERSION}") execution = remote.execute(flyte_launch_plan, {"a": 10, "b": "foobar"}, wait=True) assert execution.node_executions["n0"].inputs == {"a": 10} assert execution.node_executions["n0"].outputs == {"t1_int_output": 12, "c": "world"} @@ -61,10 +57,8 @@ def fetch_execute_launch_plan_with_args(flyteclient, flyte_workflows_register): def test_monitor_workflow_execution(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_environment(PROJECT, "development", f"v{VERSION}") - flyte_launch_plan = remote.fetch_launch_plan( - PROJECT, "development", "workflows.basic.hello_world.my_wf", f"v{VERSION}" - ) + remote = FlyteRemote.from_environment(PROJECT, "development") + flyte_launch_plan = remote.fetch_launch_plan(name="workflows.basic.hello_world.my_wf", version=f"v{VERSION}") execution = remote.execute(flyte_launch_plan, {}) poll_interval = datetime.timedelta(seconds=1) @@ -99,10 +93,8 @@ def test_monitor_workflow_execution(flyteclient, flyte_workflows_register): def test_fetch_execute_launch_plan_with_subworkflows(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_environment(PROJECT, "development", f"v{VERSION}") - flyte_launch_plan = remote.fetch_launch_plan( - PROJECT, "development", "workflows.basic.subworkflows.parent_wf", f"v{VERSION}" - ) + remote = FlyteRemote.from_environment(PROJECT, "development") + flyte_launch_plan = remote.fetch_launch_plan(name="workflows.basic.subworkflows.parent_wf", version=f"v{VERSION}") execution = remote.execute(flyte_launch_plan, {"a": 101}, wait=True) # check node execution inputs and outputs assert execution.node_executions["n0"].inputs == {"a": 101} @@ -117,15 +109,15 @@ def test_fetch_execute_launch_plan_with_subworkflows(flyteclient, flyte_workflow def test_fetch_execute_workflow(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_environment(PROJECT, "development", f"v{VERSION}") - flyte_workflow = remote.fetch_workflow(PROJECT, "development", "workflows.basic.hello_world.my_wf", f"v{VERSION}") + remote = FlyteRemote.from_environment(PROJECT, "development") + flyte_workflow = remote.fetch_workflow(name="workflows.basic.hello_world.my_wf", version=f"v{VERSION}") execution = remote.execute(flyte_workflow, {}, wait=True) assert execution.outputs["o0"] == "hello world" def test_fetch_execute_task(flyteclient, flyte_workflows_register): - remote = FlyteRemote.from_environment(PROJECT, "development", f"v{VERSION}") - flyte_task = remote.fetch_task(PROJECT, "development", "workflows.basic.basic_workflow.t1", f"v{VERSION}") + remote = FlyteRemote.from_environment(PROJECT, "development") + flyte_task = remote.fetch_task(name="workflows.basic.basic_workflow.t1", version=f"v{VERSION}") execution = remote.execute(flyte_task, {"a": 10}, wait=True) assert execution.outputs["t1_int_output"] == 12 assert execution.outputs["c"] == "world" @@ -146,8 +138,8 @@ def test_execute_python_task(flyteclient, flyte_workflows_register): t1._name = t1.name.replace("mock_flyte_repo.", "") _set_env() - remote = FlyteRemote.from_environment(PROJECT, "development", f"v{VERSION}") - execution = remote.execute(t1, inputs={"a": 10}, wait=True) + remote = FlyteRemote.from_environment(PROJECT, "development") + execution = remote.execute(t1, inputs={"a": 10}, version=f"v{VERSION}", wait=True) assert execution.outputs["t1_int_output"] == 12 assert execution.outputs["c"] == "world" @@ -160,12 +152,12 @@ def test_execute_python_workflow_and_launch_plan(flyteclient, flyte_workflows_re my_wf._name = my_wf.name.replace("mock_flyte_repo.", "") _set_env() - remote = FlyteRemote.from_environment(PROJECT, "development", f"v{VERSION}") - execution = remote.execute(my_wf, inputs={"a": 10, "b": "xyz"}, wait=True) + remote = FlyteRemote.from_environment(PROJECT, "development") + execution = remote.execute(my_wf, inputs={"a": 10, "b": "xyz"}, version=f"v{VERSION}", wait=True) assert execution.outputs["o0"] == 12 assert execution.outputs["o1"] == "xyzworld" launch_plan = LaunchPlan.get_or_create(workflow=my_wf, name=my_wf.name) - execution = remote.execute(launch_plan, inputs={"a": 14, "b": "foobar"}, wait=True) + execution = remote.execute(launch_plan, inputs={"a": 14, "b": "foobar"}, version=f"v{VERSION}", wait=True) assert execution.outputs["o0"] == 16 assert execution.outputs["o1"] == "foobarworld" diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 4df9b896ad5..6b8aa233b4b 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -56,7 +56,7 @@ def test_remote_fetch_execute_entities_task_workflow_launchplan( getattr(mock_client, CLIENT_METHODS[resource_type]).return_value = admin_entities, "" mock_client_manager.return_value.client = mock_client - remote = FlyteRemote() + remote = FlyteRemote.from_environment() fetch_method = getattr(remote, REMOTE_METHODS[resource_type]) flyte_entity_latest = fetch_method("p1", "d1", "n1", "latest") flyte_entity_latest_implicit = fetch_method("p1", "d1", "n1") @@ -81,6 +81,6 @@ def test_remote_fetch_workflow_execution(mock_url, mock_client_manager): mock_client.get_execution.return_value = admin_workflow_execution mock_client_manager.return_value.client = mock_client - remote = FlyteRemote() + remote = FlyteRemote.from_environment() flyte_workflow_execution = remote.fetch_workflow_execution("p1", "d1", "n1") assert flyte_workflow_execution.id == admin_workflow_execution.id