From 033c9e621224f8b46793928ce48fdf2fec31c841 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Thu, 24 Oct 2024 13:07:53 -0700 Subject: [PATCH] updated Signed-off-by: Ketan Umare --- flytekit/remote/executions.py | 8 +++----- flytekit/remote/remote.py | 24 ++++++++++++++++++------ 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/flytekit/remote/executions.py b/flytekit/remote/executions.py index 5e2fabbddb..1f2c8f4569 100644 --- a/flytekit/remote/executions.py +++ b/flytekit/remote/executions.py @@ -166,17 +166,15 @@ def wait( ) -> "FlyteWorkflowExecution": """ Wait for the execution to complete. This is a blocking call. + :param timeout: The maximum amount of time to wait for the execution to complete. It can be a timedelta or - a duration in seconds as int. + a duration in seconds as int. :param poll_interval: The amount of time to wait between polling the state of the execution. It can be a timedelta or a duration in seconds as int. + :param sync_nodes: Whether to sync the state of the nodes as well. """ if self._remote is None: raise user_exceptions.FlyteAssertion("Cannot wait without a remote") - if poll_interval is not None and not isinstance(poll_interval, timedelta): - poll_interval = timedelta(seconds=poll_interval) - if timeout is not None and not isinstance(timeout, timedelta): - timeout = timedelta(seconds=timeout) return self._remote.wait(self, timeout=timeout, poll_interval=poll_interval, sync_nodes=sync_nodes) def _repr_html_(self) -> str: diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 110ffc485f..02808ddef6 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -245,7 +245,7 @@ def __init__( default_project: typing.Optional[str] = None, default_domain: typing.Optional[str] = None, data_upload_location: str = "flyte://my-s3-bucket/", - interactive_mode_enabled: bool = False, + interactive_mode_enabled: typing.Optional[bool] = None, **kwargs, ): """Initialize a FlyteRemote object. @@ -256,11 +256,16 @@ def __init__( :param default_domain: default domain to use when fetching or executing flyte entities. :param data_upload_location: this is where all the default data will be uploaded when providing inputs. The default location - `s3://my-s3-bucket/data` works for sandbox/demo environment. Please override this for non-sandbox cases. - :param interactive_mode_enabled: If set to True, the FlyteRemote will pickle the task/workflow. + :param interactive_mode_enabled: If set to True, the FlyteRemote will pickle the task/workflow, if False, + it will not. If set to None, then it will automatically detect if it is running in an interactive environment + like a Jupyter notebook and enable interactive mode. """ if config is None or config.platform is None or config.platform.endpoint is None: raise user_exceptions.FlyteAssertion("Flyte endpoint should be provided.") + if interactive_mode_enabled is None: + interactive_mode_enabled = ipython_check() + if interactive_mode_enabled is True: logger.warning("Jupyter notebook and interactive task support is still alpha.") @@ -2106,18 +2111,25 @@ def execute_local_launch_plan( def wait( self, execution: FlyteWorkflowExecution, - timeout: typing.Optional[timedelta] = None, - poll_interval: typing.Optional[timedelta] = None, + timeout: typing.Optional[typing.Union[timedelta, int]] = None, + poll_interval: typing.Optional[typing.Union[timedelta, int]] = None, sync_nodes: bool = True, ) -> FlyteWorkflowExecution: """Wait for an execution to finish. :param execution: execution object to wait on - :param timeout: maximum amount of time to wait - :param poll_interval: sync workflow execution at this interval + :param timeout: maximum amount of time to wait. It can be a timedelta or a + duration in seconds as int. + :param poll_interval: sync workflow execution at this interval. It can be a + timedelta or a duration in seconds as int. :param sync_nodes: passed along to the sync call for the workflow execution """ + if poll_interval is not None and not isinstance(poll_interval, timedelta): + poll_interval = timedelta(seconds=poll_interval) poll_interval = poll_interval or timedelta(seconds=30) + + if timeout is not None and not isinstance(timeout, timedelta): + timeout = timedelta(seconds=timeout) time_to_give_up = datetime.max if timeout is None else datetime.now() + timeout while datetime.now() < time_to_give_up: