Skip to content

Commit

Permalink
[Jobs] Track real job process (#31306)
Browse files Browse the repository at this point in the history
Resolves the issue described in #31274. On Linux systems, when a stop signal is sent, instead of killing + waiting on only the shell process (which starts the actual job as a child process), we want to kill all the children of the shell process along with the shell process itself, and poll all processes until they exit or send a force SIGKILL on timeout. (This change is compatible with Mac OSX systems as well)

Co-authored-by: shrekris-anyscale <[email protected]>
  • Loading branch information
2 people authored and AmeerHajAli committed Jan 12, 2023
1 parent af2c7b0 commit 4d56c79
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 41 deletions.
90 changes: 51 additions & 39 deletions dashboard/modules/job/job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import os
import psutil
import random
import signal
import string
Expand Down Expand Up @@ -325,6 +326,24 @@ async def _polling(self, child_process: subprocess.Popen) -> int:
# still running, yield control, 0.1s by default
await asyncio.sleep(self.SUBPROCESS_POLL_PERIOD_S)

async def _poll_all(self, processes: List[psutil.Process]):
"""Poll processes until all are completed."""
while True:
(_, alive) = psutil.wait_procs(processes, timeout=0)
if len(alive) == 0:
return
else:
await asyncio.sleep(self.SUBPROCESS_POLL_PERIOD_S)

def _kill_processes(self, processes: List[psutil.Process], sig: signal.Signals):
"""Ensure each process is already finished or send a kill signal."""
for proc in processes:
try:
os.kill(proc.pid, sig)
except ProcessLookupError:
# Process is already dead
pass

async def run(
self,
# Signal actor used in testing to capture PENDING -> RUNNING cases
Expand Down Expand Up @@ -374,60 +393,53 @@ async def run(
)
log_path = self._log_client.get_log_file_path(self._job_id)
child_process = self._exec_entrypoint(log_path)
child_pid = child_process.pid

polling_task = create_task(self._polling(child_process))
finished, _ = await asyncio.wait(
[polling_task, self._stop_event.wait()], return_when=FIRST_COMPLETED
)

if self._stop_event.is_set():
polling_task.cancel()
if sys.platform == "win32" and self._win32_job_object:
polling_task.cancel()
win32job.TerminateJobObject(self._win32_job_object, -1)
elif sys.platform != "win32":
stop_signal = os.environ.get("RAY_JOB_STOP_SIGNAL", "SIGTERM")
if stop_signal not in self.VALID_STOP_SIGNALS:
logger.warning(
f"{stop_signal} not a valid stop signal. Terminating "
"job with SIGTERM."
)
stop_signal = "SIGTERM"

job_process = psutil.Process(child_pid)
proc_to_kill = [job_process] + job_process.children(recursive=True)

# Send stop signal and wait for job to terminate gracefully,
# otherwise SIGKILL job forcefully after timeout.
self._kill_processes(proc_to_kill, getattr(signal, stop_signal))
try:
stop_signal = os.environ.get("RAY_JOB_STOP_SIGNAL", "SIGTERM")
if stop_signal not in self.VALID_STOP_SIGNALS:
logger.warning(
f"{stop_signal} not a valid stop signal. Terminating "
"job with SIGTERM."
stop_job_wait_time = int(
os.environ.get(
"RAY_JOB_STOP_WAIT_TIME_S",
self.DEFAULT_RAY_JOB_STOP_WAIT_TIME_S,
)
stop_signal = "SIGTERM"
os.killpg(
os.getpgid(child_process.pid),
getattr(signal, stop_signal),
)
except ProcessLookupError:
# Process already completed.
poll_job_stop_task = create_task(self._poll_all(proc_to_kill))
await asyncio.wait_for(poll_job_stop_task, stop_job_wait_time)
logger.info(
f"Job {self._job_id} completed on its own before it could "
"be manually terminated."
f"Job {self._job_id} has been terminated gracefully "
f"with {stop_signal}."
)
pass
else:
# Wait for job to exit gracefully, otherwise kill process
# forcefully after timeout.
try:
stop_job_wait_time = int(
os.environ.get(
"RAY_JOB_STOP_WAIT_TIME_S",
self.DEFAULT_RAY_JOB_STOP_WAIT_TIME_S,
)
)
await asyncio.wait_for(polling_task, stop_job_wait_time)
logger.info(
f"Job {self._job_id} has been terminated gracefully "
f"with {stop_signal}."
)
except asyncio.TimeoutError:
logger.warning(
f"Attempt to gracefully terminate job {self._job_id} "
f"through {stop_signal} has timed out after "
f"{stop_job_wait_time} seconds. Job is now being "
"force-killed."
)
polling_task.cancel()
child_process.kill()
except asyncio.TimeoutError:
logger.warning(
f"Attempt to gracefully terminate job {self._job_id} "
f"through {stop_signal} has timed out after "
f"{stop_job_wait_time} seconds. Job is now being "
"force-killed with SIGKILL."
)
self._kill_processes(proc_to_kill, signal.SIGKILL)
await self._job_info_client.put_status(self._job_id, JobStatus.STOPPED)
else:
# Child process finished execution and no stop event is set
Expand Down
15 changes: 13 additions & 2 deletions dashboard/modules/job/tests/test_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,10 @@ async def test_stopped_job(self, job_manager):
job_manager.stop_job(job_id)

async for lines in job_manager.tail_job_logs(job_id):
assert all(s == "Waiting..." for s in lines.strip().split("\n"))
assert all(
s == "Waiting..." or s == "Terminated"
for s in lines.strip().split("\n")
)
print(lines, end="")

await async_wait_for_condition_async_predicate(
Expand Down Expand Up @@ -850,6 +853,14 @@ def handler(*args):

assert job_manager.stop_job(job_id) is True

with pytest.raises(RuntimeError):
await async_wait_for_condition_async_predicate(
check_job_stopped,
job_manager=job_manager,
job_id=job_id,
timeout=stop_timeout - 1,
)

await async_wait_for_condition(
lambda: "SIGTERM signal handled!" in job_manager.get_job_logs(job_id)
)
Expand All @@ -858,7 +869,7 @@ def handler(*args):
check_job_stopped,
job_manager=job_manager,
job_id=job_id,
timeout=stop_timeout + 10,
timeout=10,
)


Expand Down

0 comments on commit 4d56c79

Please sign in to comment.