Skip to content

Commit

Permalink
Remove sync singledispatch, add option for top-level only sync (#681)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Oct 7, 2021
1 parent 243adb7 commit 1fb425f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 27 deletions.
57 changes: 32 additions & 25 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ def _(

def wait(
self,
execution: typing.Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution],
execution: FlyteWorkflowExecution,
timeout: typing.Optional[timedelta] = None,
poll_interval: typing.Optional[timedelta] = None,
):
Expand All @@ -941,7 +941,7 @@ def wait(
time_to_give_up = datetime.max if timeout is None else datetime.utcnow() + timeout

while datetime.utcnow() < time_to_give_up:
execution = self.sync(execution)
execution = self.sync_workflow_execution(execution)
if execution.is_complete:
return execution
time.sleep(poll_interval.total_seconds())
Expand All @@ -952,24 +952,32 @@ def wait(
# Sync Execution State #
########################

@singledispatchmethod
def sync(
self,
execution: typing.Union[FlyteWorkflowExecution, FlyteNodeExecution, FlyteTaskExecution],
execution: FlyteWorkflowExecution,
entity_definition: typing.Union[FlyteWorkflow, FlyteTask] = None,
):
"""Sync a flyte execution object with its corresponding remote state.
This method syncs the inputs and outputs of the execution object and all of its child node executions.
:param execution: workflow execution to sync.
:param entity_definition: optional, reference entity definition which adds more context to this execution entity
sync_nodes: bool = True,
) -> FlyteWorkflowExecution:
"""
This function was previously a singledispatchmethod. We've removed that but this function remains
so that we don't break people.
:param execution:
:param entity_definition:
:param sync_nodes: By default sync will fetch data on all underlying node executions (recursively,
so subworkflows will also get picked up). Set this to False in order to prevent that (which
will make this call faster).
:return: Returns the same execution object, but with additional information pulled in.
"""
raise NotImplementedError(f"Execution type {type(execution)} cannot be synced.")
if not isinstance(execution, FlyteWorkflowExecution):
raise ValueError(f"remote.sync should only be called on workflow executions, got {type(execution)}")
return self.sync_workflow_execution(execution, entity_definition, sync_nodes)

@sync.register
def _(
self, execution: FlyteWorkflowExecution, entity_definition: typing.Union[FlyteWorkflow, FlyteTask] = None
def sync_workflow_execution(
self,
execution: FlyteWorkflowExecution,
entity_definition: typing.Union[FlyteWorkflow, FlyteTask] = None,
sync_nodes: bool = True,
) -> FlyteWorkflowExecution:

"""Sync a FlyteWorkflowExecution object with its corresponding remote state."""
Expand All @@ -988,14 +996,14 @@ def _(

# sync closure, node executions, and inputs/outputs
execution._closure = self.client.get_execution(execution.id).closure
execution._node_executions = {
node.id.node_id: self.sync(FlyteNodeExecution.promote_from_model(node), flyte_entity)
for node in iterate_node_executions(self.client, execution.id)
}
if sync_nodes:
execution._node_executions = {
node.id.node_id: self.sync_node_execution(FlyteNodeExecution.promote_from_model(node), flyte_entity)
for node in iterate_node_executions(self.client, execution.id)
}
return self._assign_inputs_and_outputs(execution, execution_data, flyte_entity.interface)

@sync.register
def _(
def sync_node_execution(
self, execution: FlyteNodeExecution, entity_definition: typing.Union[FlyteWorkflow, FlyteTask] = None
) -> FlyteNodeExecution:
"""Sync a FlyteNodeExecution object with its corresponding remote state."""
Expand All @@ -1010,7 +1018,7 @@ def _(
execution._closure = self.client.get_node_execution(execution.id).closure
if execution.metadata.is_parent_node:
execution._subworkflow_node_executions = [
self.sync(FlyteNodeExecution.promote_from_model(node), entity_definition)
self.sync_node_execution(FlyteNodeExecution.promote_from_model(node), entity_definition)
for node in iterate_node_executions(
self.client,
workflow_execution_identifier=execution.id.execution_id,
Expand All @@ -1019,7 +1027,7 @@ def _(
]
else:
execution._task_executions = [
self.sync(FlyteTaskExecution.promote_from_model(t))
self.sync_task_execution(FlyteTaskExecution.promote_from_model(t))
for t in iterate_task_executions(self.client, execution.id)
]
execution._interface = self._get_node_execution_interface(execution, entity_definition)
Expand All @@ -1029,8 +1037,7 @@ def _(
execution.interface,
)

@sync.register
def _(
def sync_task_execution(
self, execution: FlyteTaskExecution, entity_definition: typing.Union[FlyteWorkflow, FlyteTask] = None
) -> FlyteTaskExecution:
"""Sync a FlyteTaskExecution object with its corresponding remote state."""
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_monitor_workflow_execution(flyteclient, flyte_workflows_register, flyte
poll_interval = datetime.timedelta(seconds=1)
time_to_give_up = datetime.datetime.utcnow() + datetime.timedelta(seconds=60)

execution = remote.sync(execution)
execution = remote.sync_workflow_execution(execution)
while datetime.datetime.utcnow() < time_to_give_up:

if execution.is_complete:
Expand All @@ -94,7 +94,7 @@ def test_monitor_workflow_execution(flyteclient, flyte_workflows_register, flyte
execution.outputs

time.sleep(poll_interval.total_seconds())
execution = remote.sync(execution)
execution = remote.sync_workflow_execution(execution)

if execution.node_executions:
assert execution.node_executions["start-node"].closure.phase == 3 # SUCCEEEDED
Expand Down

0 comments on commit 1fb425f

Please sign in to comment.