diff --git a/dashboard/modules/job/job_manager.py b/dashboard/modules/job/job_manager.py index 44bae5d76ba0..fbe608507a46 100644 --- a/dashboard/modules/job/job_manager.py +++ b/dashboard/modules/job/job_manager.py @@ -3,6 +3,7 @@ import json import logging import os +import psutil import random import signal import string @@ -325,6 +326,24 @@ async def _polling(self, child_process: subprocess.Popen) -> int: # still running, yield control, 0.1s by default await asyncio.sleep(self.SUBPROCESS_POLL_PERIOD_S) + async def _poll_all(self, processes: List[psutil.Process]): + """Poll processes until all are completed.""" + while True: + (_, alive) = psutil.wait_procs(processes, timeout=0) + if len(alive) == 0: + return + else: + await asyncio.sleep(self.SUBPROCESS_POLL_PERIOD_S) + + def _kill_processes(self, processes: List[psutil.Process], sig: signal.Signals): + """Ensure each process is already finished or send a kill signal.""" + for proc in processes: + try: + os.kill(proc.pid, sig) + except ProcessLookupError: + # Process is already dead + pass + async def run( self, # Signal actor used in testing to capture PENDING -> RUNNING cases @@ -374,6 +393,7 @@ async def run( ) log_path = self._log_client.get_log_file_path(self._job_id) child_process = self._exec_entrypoint(log_path) + child_pid = child_process.pid polling_task = create_task(self._polling(child_process)) finished, _ = await asyncio.wait( @@ -381,53 +401,45 @@ async def run( ) if self._stop_event.is_set(): + polling_task.cancel() if sys.platform == "win32" and self._win32_job_object: - polling_task.cancel() win32job.TerminateJobObject(self._win32_job_object, -1) elif sys.platform != "win32": + stop_signal = os.environ.get("RAY_JOB_STOP_SIGNAL", "SIGTERM") + if stop_signal not in self.VALID_STOP_SIGNALS: + logger.warning( + f"{stop_signal} not a valid stop signal. Terminating " + "job with SIGTERM." + ) + stop_signal = "SIGTERM" + + job_process = psutil.Process(child_pid) + proc_to_kill = [job_process] + job_process.children(recursive=True) + + # Send stop signal and wait for job to terminate gracefully, + # otherwise SIGKILL job forcefully after timeout. + self._kill_processes(proc_to_kill, getattr(signal, stop_signal)) try: - stop_signal = os.environ.get("RAY_JOB_STOP_SIGNAL", "SIGTERM") - if stop_signal not in self.VALID_STOP_SIGNALS: - logger.warning( - f"{stop_signal} not a valid stop signal. Terminating " - "job with SIGTERM." + stop_job_wait_time = int( + os.environ.get( + "RAY_JOB_STOP_WAIT_TIME_S", + self.DEFAULT_RAY_JOB_STOP_WAIT_TIME_S, ) - stop_signal = "SIGTERM" - os.killpg( - os.getpgid(child_process.pid), - getattr(signal, stop_signal), ) - except ProcessLookupError: - # Process already completed. + poll_job_stop_task = create_task(self._poll_all(proc_to_kill)) + await asyncio.wait_for(poll_job_stop_task, stop_job_wait_time) logger.info( - f"Job {self._job_id} completed on its own before it could " - "be manually terminated." + f"Job {self._job_id} has been terminated gracefully " + f"with {stop_signal}." ) - pass - else: - # Wait for job to exit gracefully, otherwise kill process - # forcefully after timeout. - try: - stop_job_wait_time = int( - os.environ.get( - "RAY_JOB_STOP_WAIT_TIME_S", - self.DEFAULT_RAY_JOB_STOP_WAIT_TIME_S, - ) - ) - await asyncio.wait_for(polling_task, stop_job_wait_time) - logger.info( - f"Job {self._job_id} has been terminated gracefully " - f"with {stop_signal}." - ) - except asyncio.TimeoutError: - logger.warning( - f"Attempt to gracefully terminate job {self._job_id} " - f"through {stop_signal} has timed out after " - f"{stop_job_wait_time} seconds. Job is now being " - "force-killed." - ) - polling_task.cancel() - child_process.kill() + except asyncio.TimeoutError: + logger.warning( + f"Attempt to gracefully terminate job {self._job_id} " + f"through {stop_signal} has timed out after " + f"{stop_job_wait_time} seconds. Job is now being " + "force-killed with SIGKILL." + ) + self._kill_processes(proc_to_kill, signal.SIGKILL) await self._job_info_client.put_status(self._job_id, JobStatus.STOPPED) else: # Child process finished execution and no stop event is set diff --git a/dashboard/modules/job/tests/test_job_manager.py b/dashboard/modules/job/tests/test_job_manager.py index e85a078035e9..2219175c6b36 100644 --- a/dashboard/modules/job/tests/test_job_manager.py +++ b/dashboard/modules/job/tests/test_job_manager.py @@ -774,7 +774,10 @@ async def test_stopped_job(self, job_manager): job_manager.stop_job(job_id) async for lines in job_manager.tail_job_logs(job_id): - assert all(s == "Waiting..." for s in lines.strip().split("\n")) + assert all( + s == "Waiting..." or s == "Terminated" + for s in lines.strip().split("\n") + ) print(lines, end="") await async_wait_for_condition_async_predicate( @@ -850,6 +853,14 @@ def handler(*args): assert job_manager.stop_job(job_id) is True + with pytest.raises(RuntimeError): + await async_wait_for_condition_async_predicate( + check_job_stopped, + job_manager=job_manager, + job_id=job_id, + timeout=stop_timeout - 1, + ) + await async_wait_for_condition( lambda: "SIGTERM signal handled!" in job_manager.get_job_logs(job_id) ) @@ -858,7 +869,7 @@ def handler(*args): check_job_stopped, job_manager=job_manager, job_id=job_id, - timeout=stop_timeout + 10, + timeout=10, )