diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index b7cd1bdb8b..cbf24615e9 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -927,7 +927,7 @@ def _( def wait( self, - execution: typing.Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution], + execution: FlyteWorkflowExecution, timeout: typing.Optional[timedelta] = None, poll_interval: typing.Optional[timedelta] = None, ): @@ -941,7 +941,7 @@ def wait( time_to_give_up = datetime.max if timeout is None else datetime.utcnow() + timeout while datetime.utcnow() < time_to_give_up: - execution = self.sync(execution) + execution = self.sync_workflow_execution(execution) if execution.is_complete: return execution time.sleep(poll_interval.total_seconds()) @@ -952,24 +952,32 @@ def wait( # Sync Execution State # ######################## - @singledispatchmethod def sync( self, - execution: typing.Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution], + execution: FlyteWorkflowExecution, entity_definition: typing.Union[FlyteWorkflow, FlyteTask] = None, - ): - """Sync a flyte execution object with its corresponding remote state. - - This method syncs the inputs and outputs of the execution object and all of its child node executions. - - :param execution: workflow execution to sync. - :param entity_definition: optional, reference entity definition which adds more context to this execution entity + sync_nodes: bool = True, + ) -> FlyteWorkflowExecution: + """ + This function was previously a singledispatchmethod. We've removed that but this function remains + so that we don't break people. + + :param execution: + :param entity_definition: + :param sync_nodes: By default sync will fetch data on all underlying node executions (recursively, + so subworkflows will also get picked up). Set this to False in order to prevent that (which + will make this call faster). + :return: Returns the same execution object, but with additional information pulled in. """ - raise NotImplementedError(f"Execution type {type(execution)} cannot be synced.") + if not isinstance(execution, FlyteWorkflowExecution): + raise ValueError(f"remote.sync should only be called on workflow executions, got {type(execution)}") + return self.sync_workflow_execution(execution, entity_definition, sync_nodes) - @sync.register - def _( - self, execution: FlyteWorkflowExecution, entity_definition: typing.Union[FlyteWorkflow, FlyteTask] = None + def sync_workflow_execution( + self, + execution: FlyteWorkflowExecution, + entity_definition: typing.Union[FlyteWorkflow, FlyteTask] = None, + sync_nodes: bool = True, ) -> FlyteWorkflowExecution: """Sync a FlyteWorkflowExecution object with its corresponding remote state.""" @@ -988,14 +996,14 @@ def _( # sync closure, node executions, and inputs/outputs execution._closure = self.client.get_execution(execution.id).closure - execution._node_executions = { - node.id.node_id: self.sync(FlyteNodeExecution.promote_from_model(node), flyte_entity) - for node in iterate_node_executions(self.client, execution.id) - } + if sync_nodes: + execution._node_executions = { + node.id.node_id: self.sync_node_execution(FlyteNodeExecution.promote_from_model(node), flyte_entity) + for node in iterate_node_executions(self.client, execution.id) + } return self._assign_inputs_and_outputs(execution, execution_data, flyte_entity.interface) - @sync.register - def _( + def sync_node_execution( self, execution: FlyteNodeExecution, entity_definition: typing.Union[FlyteWorkflow, FlyteTask] = None ) -> FlyteNodeExecution: """Sync a FlyteNodeExecution object with its corresponding remote state.""" @@ -1010,7 +1018,7 @@ def _( execution._closure = self.client.get_node_execution(execution.id).closure if execution.metadata.is_parent_node: execution._subworkflow_node_executions = [ - self.sync(FlyteNodeExecution.promote_from_model(node), entity_definition) + self.sync_node_execution(FlyteNodeExecution.promote_from_model(node), entity_definition) for node in iterate_node_executions( self.client, workflow_execution_identifier=execution.id.execution_id, @@ -1019,7 +1027,7 @@ def _( ] else: execution._task_executions = [ - self.sync(FlyteTaskExecution.promote_from_model(t)) + self.sync_task_execution(FlyteTaskExecution.promote_from_model(t)) for t in iterate_task_executions(self.client, execution.id) ] execution._interface = self._get_node_execution_interface(execution, entity_definition) @@ -1029,8 +1037,7 @@ def _( execution.interface, ) - @sync.register - def _( + def sync_task_execution( self, execution: FlyteTaskExecution, entity_definition: typing.Union[FlyteWorkflow, FlyteTask] = None ) -> FlyteTaskExecution: """Sync a FlyteTaskExecution object with its corresponding remote state.""" diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index d0af4933f3..2b61f220c6 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -82,7 +82,7 @@ def test_monitor_workflow_execution(flyteclient, flyte_workflows_register, flyte poll_interval = datetime.timedelta(seconds=1) time_to_give_up = datetime.datetime.utcnow() + datetime.timedelta(seconds=60) - execution = remote.sync(execution) + execution = remote.sync_workflow_execution(execution) while datetime.datetime.utcnow() < time_to_give_up: if execution.is_complete: @@ -94,7 +94,7 @@ def test_monitor_workflow_execution(flyteclient, flyte_workflows_register, flyte execution.outputs time.sleep(poll_interval.total_seconds()) - execution = remote.sync(execution) + execution = remote.sync_workflow_execution(execution) if execution.node_executions: assert execution.node_executions["start-node"].closure.phase == 3 # SUCCEEEDED