From f7a1bcfd98b0cb004d137e9d853cbd8552e16e78 Mon Sep 17 00:00:00 2001 From: Jonathan Karlsen Date: Wed, 25 Sep 2024 11:53:57 +0200 Subject: [PATCH] fixup --- src/_ert/forward_model_runner/job.py | 144 ++++++++---------- .../reporting/interactive.py | 4 +- .../forward_model_runner/reporting/message.py | 6 +- 3 files changed, 66 insertions(+), 88 deletions(-) diff --git a/src/_ert/forward_model_runner/job.py b/src/_ert/forward_model_runner/job.py index 85beedaff06..494e81af4a9 100644 --- a/src/_ert/forward_model_runner/job.py +++ b/src/_ert/forward_model_runner/job.py @@ -11,10 +11,12 @@ from datetime import datetime as dt from pathlib import Path from subprocess import Popen, run -from typing import Dict, Generator, Optional, Sequence, Tuple +from typing import Dict, Generator, List, Optional, Sequence, Tuple, cast from psutil import AccessDenied, NoSuchProcess, Process, TimeoutExpired, ZombieProcess +from ert.config.forward_model_step import ForwardModelStepJSON + from .io import check_executable from .reporting.message import ( Exited, @@ -26,7 +28,7 @@ logger = logging.getLogger(__name__) -def killed_by_oom(pids: set[int]) -> bool: +def killed_by_oom(pids: Sequence[int]) -> bool: """Will try to detect if a process (or any of its descendants) was killed by the Linux OOM-killer. @@ -76,26 +78,28 @@ def killed_by_oom(pids: set[int]) -> bool: class Job: MEMORY_POLL_PERIOD = 5 # Seconds between memory polls - def __init__(self, job_data, index, sleep_interval=1) -> None: + def __init__( + self, job_data: ForwardModelStepJSON, index: int, sleep_interval: int = 1 + ) -> None: self.sleep_interval = sleep_interval - self.job_data: Dict[str, str] = job_data + self.job_data = job_data self.index = index self.std_err = job_data.get("stderr") self.std_out = job_data.get("stdout") - def run(self) -> Generator[Start | Exited | Running]: + def run(self) -> Generator[Start | Exited | Running | None]: try: for msg in self._run(): yield msg + except StopIteration as e: + raise e except Exception as e: yield Exited(self, exit_code=1).with_error(str(e)) def create_start_message_and_check_job_files(self) -> Start: start_message = Start(self) - errors = self._check_job_files() - - errors.extend(self._assert_arg_list()) + errors = [*self._check_job_files()] self._dump_exec_env() @@ -103,9 +107,9 @@ def create_start_message_and_check_job_files(self) -> Start: start_message = start_message.with_error("\n".join(errors)) return start_message - def _build_arg_list(self): + def _build_arg_list(self) -> List[str]: executable = self.job_data.get("executable") - + # assert executable is not None combined_arg_list = [executable] if arg_list := self.job_data.get("argList"): combined_arg_list += arg_list @@ -117,7 +121,7 @@ def _open_file_handles( io.TextIOWrapper | None, io.TextIOWrapper | None, io.TextIOWrapper | None ]: if self.job_data.get("stdin"): - stdin = open(self.job_data.get("stdin"), encoding="utf-8") # noqa + stdin = open(cast(Path, self.job_data.get("stdin")), encoding="utf-8") # noqa else: stdin = None @@ -141,18 +145,18 @@ def _open_file_handles( return (stdin, stdout, stderr) - def _create_environment(self) -> Dict: - environment = self.job_data.get("environment") - if environment is not None: - environment = {**os.environ, **environment} - return environment + def _create_environment(self) -> Optional[Dict[str, str]]: + combined_environment = None + if environment := self.job_data.get("environment"): + combined_environment = {**os.environ, **environment} + return combined_environment - def _run(self) -> contextlib.Generator[Start | Exited | Running]: + def _run(self) -> Generator[Start | Exited | Running | None]: start_message = self.create_start_message_and_check_job_files() yield start_message if not start_message.success(): - return + raise StopIteration() arg_list = self._build_arg_list() @@ -160,7 +164,7 @@ def _run(self) -> contextlib.Generator[Start | Exited | Running]: # stdin/stdout/stderr are closed at the end of this function target_file = self.job_data.get("target_file") - target_file_mtime: int = _get_target_file_ntime(target_file) + target_file_mtime: Optional[int] = _get_target_file_ntime(target_file) try: proc = Popen( @@ -177,7 +181,7 @@ def _run(self) -> contextlib.Generator[Start | Exited | Running]: ) ensure_file_handles_closed([stdin, stdout, stderr]) yield exited_message - return + raise StopIteration() from None exit_code = None @@ -201,20 +205,22 @@ def _run(self) -> contextlib.Generator[Start | Exited | Running]: try: exit_code = process.wait(timeout=self.MEMORY_POLL_PERIOD) except TimeoutExpired: - exited_msg = self.handle_process_timeout_and_create_exited_msg( - process, proc + potential_exited_msg = ( + self.handle_process_timeout_and_create_exited_msg(process, proc) ) fm_step_pids |= { int(child.pid) for child in process.children(recursive=True) } - if isinstance(exited_msg, Exited): - yield exited_msg - return + if isinstance(potential_exited_msg, Exited): + yield potential_exited_msg + + raise StopIteration() from None ensure_file_handles_closed([stdin, stdout, stderr]) exited_message = self._create_exited_message_based_on_exit_code( max_memory_usage, target_file_mtime, exit_code, fm_step_pids ) + assert exited_message.job yield exited_message def _create_exited_message_based_on_exit_code( @@ -224,20 +230,16 @@ def _create_exited_message_based_on_exit_code( exit_code: int, fm_step_pids: Sequence[int], ) -> Exited: - # exit_code = proc.returncode - if exit_code != 0: exited_message = self._create_exited_msg_for_non_zero_exit_code( max_memory_usage, exit_code, fm_step_pids ) return exited_message - # exit_code is 0 - + exited_message = Exited(self, exit_code) if self.job_data.get("error_file") and os.path.exists( self.job_data["error_file"] ): - exited_message = Exited(self, exit_code) return exited_message.with_error( f'Found the error file:{self.job_data["error_file"]} - job failed.' ) @@ -271,34 +273,33 @@ def _create_exited_msg_for_non_zero_exit_code( ) def handle_process_timeout_and_create_exited_msg( - self, process: Process, proc: Popen + self, process: Process, proc: Popen[Process] ) -> Exited | None: max_running_minutes = self.job_data.get("max_running_minutes") run_start_time = dt.now() run_time = dt.now() - run_start_time - if ( - max_running_minutes is not None - and run_time.seconds > max_running_minutes * 60 - ): - # If the spawned process is not in the same process group as - # the callee (job_dispatch), we will kill the process group - # explicitly. - # - # Propagating the unsuccessful Exited message will kill the - # callee group. See job_dispatch.py. - process_group_id = os.getpgid(proc.pid) - this_group_id = os.getpgid(os.getpid()) - if process_group_id != this_group_id: - os.killpg(process_group_id, signal.SIGKILL) - - return Exited(self, proc.returncode).with_error( - ( - f"Job:{self.name()} has been running " - f"for more than {max_running_minutes} " - "minutes - explicitly killed." - ) + if max_running_minutes is None or run_time.seconds > max_running_minutes * 60: + return None + + # If the spawned process is not in the same process group as + # the callee (job_dispatch), we will kill the process group + # explicitly. + # + # Propagating the unsuccessful Exited message will kill the + # callee group. See job_dispatch.py. + process_group_id = os.getpgid(proc.pid) + this_group_id = os.getpgid(os.getpid()) + if process_group_id != this_group_id: + os.killpg(process_group_id, signal.SIGKILL) + + return Exited(self, proc.returncode).with_error( + ( + f"Job:{self.name()} has been running " + f"for more than {max_running_minutes} " + "minutes - explicitly killed." ) + ) def _handle_process_io_error_and_create_exited_message( self, e: OSError, stderr: io.TextIOWrapper | None @@ -314,31 +315,6 @@ def _handle_process_io_error_and_create_exited_message( stderr.write(msg) return Exited(self, e.errno).with_error(msg) - def _assert_arg_list(self) -> list[str]: - errors: list[str] = [] - if "arg_types" in self.job_data: - arg_types = self.job_data["arg_types"] - arg_list = self.job_data.get("argList") - for index, arg_type in enumerate(arg_types): - if arg_type == "RUNTIME_FILE": - file_path = os.path.join(os.getcwd(), arg_list[index]) - if not os.path.isfile(file_path): - errors.append( - f"In job {self.name()}: RUNTIME_FILE {arg_list[index]} " - "does not exist." - ) - if arg_type == "RUNTIME_INT": - try: - int(arg_list[index]) - except ValueError: - errors.append( - ( - f"In job {self.name()}: argument with index {index} " - "is of incorrect type, should be integer." - ) - ) - return errors - def name(self) -> str: return self.job_data["name"] @@ -346,12 +322,12 @@ def _dump_exec_env(self) -> None: exec_env = self.job_data.get("exec_env") if exec_env: exec_name, _ = os.path.splitext( - os.path.basename(self.job_data.get("executable")) + os.path.basename(cast(Path, self.job_data.get("executable"))) ) with open(f"{exec_name}_exec_env.json", "w", encoding="utf-8") as f_handle: f_handle.write(json.dumps(exec_env, indent=4)) - def _check_job_files(self)-> list[str]: + def _check_job_files(self) -> list[str]: """ Returns the empty list if no failed checks, or a list of errors in case of failed checks. @@ -361,21 +337,23 @@ def _check_job_files(self)-> list[str]: errors.append(f'Could not locate stdin file: {self.job_data["stdin"]}') if self.job_data.get("start_file") and not os.path.exists( - self.job_data["start_file"] + cast(Path, self.job_data["start_file"]) ): errors.append(f'Could not locate start_file:{self.job_data["start_file"]}') if self.job_data.get("error_file") and os.path.exists( - self.job_data.get("error_file") + cast(Path, self.job_data.get("error_file")) ): - os.unlink(self.job_data.get("error_file")) + os.unlink(cast(Path, self.job_data.get("error_file"))) if executable_error := check_executable(self.job_data.get("executable")): errors.append(str(executable_error)) return errors - def _check_target_file_is_written(self, target_file_mtime: int, timeout: int =5) -> None | str: + def _check_target_file_is_written( + self, target_file_mtime: int, timeout: int = 5 + ) -> None | str: """ Check whether or not a target_file eventually appear. Returns None in case of success, an error message in the case of failure. diff --git a/src/_ert/forward_model_runner/reporting/interactive.py b/src/_ert/forward_model_runner/reporting/interactive.py index fd489c78378..25a26b37896 100644 --- a/src/_ert/forward_model_runner/reporting/interactive.py +++ b/src/_ert/forward_model_runner/reporting/interactive.py @@ -19,12 +19,12 @@ def _report(msg: Message) -> Optional[str]: "OK" if msg.success() else _JOB_EXIT_FAILED_STRING.format( - job_name=msg.job.name(), + job_name=msg.job.name() if msg.job else "NO_NAME", exit_code="No Code", error_message=msg.error_message, ) ) - return f"Running job: {msg.job.name()} ... " + return f"Running job: {msg.job.name() if msg.job else "None"} ... " def report(self, msg: Message): _msg = self._report(msg) diff --git a/src/_ert/forward_model_runner/reporting/message.py b/src/_ert/forward_model_runner/reporting/message.py index 3efc9380392..2811488da29 100644 --- a/src/_ert/forward_model_runner/reporting/message.py +++ b/src/_ert/forward_model_runner/reporting/message.py @@ -81,7 +81,7 @@ def with_error(self, message: str): self.error_message = message return self - def success(self): + def success(self) -> bool: return self.error_message is None @@ -116,7 +116,7 @@ def __init__(self): class Start(Message): - def __init__(self, job): + def __init__(self, job: "Job"): super().__init__(job) @@ -127,7 +127,7 @@ def __init__(self, job: "Job", memory_status: ProcessTreeStatus): class Exited(Message): - def __init__(self, job, exit_code): + def __init__(self, job, exit_code: int): super().__init__(job) self.exit_code = exit_code