diff --git a/src/deadline_worker_agent/sessions/session.py b/src/deadline_worker_agent/sessions/session.py index 506c37e2..0f871528 100644 --- a/src/deadline_worker_agent/sessions/session.py +++ b/src/deadline_worker_agent/sessions/session.py @@ -1018,7 +1018,7 @@ def _action_updated_impl( *, action_status: ActionStatus, now: datetime, - ) -> None: + ) -> Optional[Future]: """Internal implementation for the callback invoked on every Open Job Description status/progress update and the completion/exit of the current action. The caller should acquire the Session._current_action_lock before calling this method. @@ -1097,7 +1097,9 @@ def _action_updated_impl( current_action=current_action, ) future.add_done_callback(on_done_with_sync_asset_outputs) - + # Returning the future just to make this method easier to test. + # Tests need to wait on the future to avoid race conditions + return future else: self._handle_action_update(is_unsuccessful, action_status, current_action, now) diff --git a/test/unit/sessions/test_session.py b/test/unit/sessions/test_session.py index 5ba445a6..310df65a 100644 --- a/test/unit/sessions/test_session.py +++ b/test/unit/sessions/test_session.py @@ -1,6 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. from __future__ import annotations +from concurrent.futures import wait from datetime import datetime, timedelta from pathlib import PurePosixPath, PureWindowsPath from threading import Event, RLock @@ -1343,10 +1344,12 @@ def test_failed_enter_env( with patch.object(session, "_sync_asset_outputs") as mock_sync_asset_outputs: # WHEN - session._action_updated_impl( + future = session._action_updated_impl( action_status=failed_action_status, now=action_complete_time, ) + if future: + wait([future]) # THEN mock_report_action_update.assert_called_once_with(expected_action_update) @@ -1410,10 +1413,12 @@ def test_failed_task_run( with patch.object(session, "_sync_asset_outputs") as mock_sync_asset_outputs: # WHEN - session._action_updated_impl( + future = session._action_updated_impl( action_status=failed_action_status, now=action_complete_time, ) + if future: + wait([future]) # THEN mock_report_action_update.assert_called_once_with(expected_action_update) @@ -1488,10 +1493,12 @@ def sync_asset_outputs_side_effect(*, current_action: CurrentAction) -> None: mock_sync_asset_outputs.side_effect = sync_asset_outputs_side_effect # WHEN - session._action_updated_impl( + future = session._action_updated_impl( action_status=success_action_status, now=action_complete_time, ) + if future: + wait([future]) # THEN mock_report_action_update.assert_called_once_with(expected_action_update) @@ -1565,10 +1572,12 @@ def mock_now(*arg, **kwarg) -> datetime: mock_datetime.now.side_effect = mock_now # WHEN - session._action_updated_impl( + future = session._action_updated_impl( action_status=success_action_status, now=action_complete_time, ) + if future: + wait([future]) # THEN mock_report_action_update.assert_called_once_with(expected_action_update) @@ -1589,14 +1598,12 @@ def test_logs_succeeded( ) -> None: """Tests that succeeded actions are logged""" # WHEN - session._action_updated_impl( + future = session._action_updated_impl( action_status=success_action_status, now=action_complete_time, ) - # This because the _action_update_impl submits a future to this thread pool executor - # The test assertion depends on this future completing and so there's a race condition - # if we do not wait for the thread pool to shutdown and all futures to complete. - session._executor.shutdown() + if future: + wait([future]) # THEN mock_mod_logger.info.assert_called_once() @@ -1618,10 +1625,12 @@ def test_logs_failed( ) -> None: """Tests that failed actions are logged""" # WHEN - session._action_updated_impl( + future = session._action_updated_impl( action_status=failed_action_status, now=action_complete_time, ) + if future: + wait([future]) # THEN mock_mod_logger.info.assert_called_once() @@ -1642,10 +1651,12 @@ def test_logs_canceled( ) -> None: """Tests that canceled actions are logged""" # WHEN - session._action_updated_impl( + future = session._action_updated_impl( action_status=canceled_action_status, now=action_complete_time, ) + if future: + wait([future]) # THEN mock_mod_logger.info.assert_called_once()