From 18d26002086fbb884f2f6c957dfdb421defafbd9 Mon Sep 17 00:00:00 2001
From: jaime
Date: Thu, 12 Oct 2023 19:35:50 -0500
Subject: [PATCH] Added Slurm command and proper setting of variables

---
 jarvis_util/shell/slurm_exec.py | 83 ++++++++++++++++++++++-----------
 1 file changed, 56 insertions(+), 27 deletions(-)

diff --git a/jarvis_util/shell/slurm_exec.py b/jarvis_util/shell/slurm_exec.py
index e67c588..8b5f10f 100644
--- a/jarvis_util/shell/slurm_exec.py
+++ b/jarvis_util/shell/slurm_exec.py
@@ -9,10 +9,11 @@
 from jarvis_util.shell.local_exec import LocalExec
 from .exec_info import ExecInfo, ExecType

+
 class SlurmExec(LocalExec):
     """
-    This class contains methods for executing a command in parallel
-    using MPI.
+    This class contains methods for executing a command
+    through the Slurm scheduler.
     """

     def __init__(self, cmd, exec_info):
@@ -20,35 +21,60 @@ def __init__(self, cmd, exec_info):
         Execute a command through sbatch

         :param cmd: A command (string) to execute
-        :param exec_info: Information needed by MPI
+        :param exec_info: Information needed by sbatch
         """
         self.cmd = cmd
-        self.job_name = exec_info
+        self.job_name = exec_info.job_name
         self.num_nodes = exec_info.num_nodes
-        self.nprocs = exec_info.nprocs
         self.ppn = exec_info.ppn
-        self.hostfile = exec_info.hostfile
-        self.mpi_env = exec_info.env
+        self.cpus_per_task = exec_info.cpus_per_task
+        self.time = exec_info.time
+        self.partition = exec_info.partition
+        self.mail_type = exec_info.mail_type
+        self.output = exec_info.pipe_stdout
+        self.error = exec_info.pipe_stderr
+        self.memory = exec_info.mem
+        self.gres = exec_info.gres
+        self.exclusive = exec_info.exclusive
+
         super().__init__(self.slurmcmd(),
                          exec_info.mod(env=exec_info.basic_env))

+    def generate_sbatch_command(self):
+        cmd = "sbatch"
+
+        # Mapping of attribute names to their corresponding sbatch option names
+        options_map = {
+            'job_name': 'job-name',
+            'num_nodes': 'nodes',
+            'ppn': 'ntasks',
+            'cpus_per_task': 'cpus-per-task',
+            'time': 'time',
+            'partition': 'partition',
+            'mail_type': 'mail-type',
+            'output': 'output',
+            'error': 'error',
+            'memory': 'mem',
+            'gres': 'gres',
+            'exclusive': 'exclusive'
+        }
+
+        for attr, option in options_map.items():
+            value = getattr(self, attr)
+            if value is not None:
+                if value is True:  # For options like 'exclusive' that don't take a value
+                    cmd += f" --{option}"
+                else:
+                    cmd += f" --{option}={value}"
+
+        cmd += f" {self.cmd}"
+        return cmd
+
     def slurmcmd(self):
-        params = [f'mpirun -n {self.nprocs}']
-        params.append('--oversubscribe')
-        if self.ppn is not None:
-            params.append(f'-ppn {self.ppn}')
-        if len(self.hostfile):
-            if self.hostfile.is_subset() or self.hostfile.path is None:
-                params.append(f'--host {",".join(self.hostfile.hosts)}')
-            else:
-                params.append(f'--hostfile {self.hostfile.path}')
-        params += [f'-x {key}=\"{val}\"'
-                   for key, val in self.mpi_env.items()]
-        params.append(self.cmd)
-        cmd = ' '.join(params)
+        cmd = self.generate_sbatch_command()
         jutil = JutilManager.get_instance()
-        if jutil.debug_mpi_exec:
+        if jutil.debug_slurm:
             print(cmd)
         return cmd

@@ -56,11 +82,14 @@
 class SlurmExecInfo(ExecInfo):
     def __init__(self, job_name=None, num_nodes=1, **kwargs):
         super().__init__(exec_type=ExecType.SLURM, **kwargs)
-        allowed_options = ['cpus_per_task', 'time', 'partition', 'mail_type', 'mail_user', 'mem', 'gres', 'exclusive', 'pipeline_file']
+        allowed_options = ['job_name', 'num_nodes', 'cpus_per_task', 'time', 'partition', 'mail_type',
+                           'mail_user', 'mem', 'gres', 'exclusive']
+        self.keys += allowed_options
+        # We use ppn, and the output and error files from the base ExecInfo
         self.job_name = job_name
         self.num_nodes = num_nodes
-        self.keys.append(['job_name', 'num_nodes'])
-        for key, value in kwargs.items():
-            if key in allowed_options:
-                setattr(self, key, value)
-                self.keys.append(key)
+        for key in allowed_options:
+            if key in kwargs:
+                setattr(self, key, kwargs[key])
+            else:
+                setattr(self, key, None)
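
For reference, a minimal standalone sketch of the flag-building logic that the new generate_sbatch_command() applies: only options that are set (not None) are emitted, and boolean flags such as --exclusive take no value. The job settings and the wrapped hostname command below are hypothetical and jarvis_util is not imported; this is an illustration, not part of the patch.

# Standalone sketch of the option-map loop used by generate_sbatch_command().
options_map = {
    'job_name': 'job-name',
    'num_nodes': 'nodes',
    'ppn': 'ntasks',
    'time': 'time',
    'memory': 'mem',
    'exclusive': 'exclusive',
}

# Hypothetical job settings; None means "option not requested".
job = {
    'job_name': 'ior_test',
    'num_nodes': 4,
    'ppn': 16,
    'time': '01:00:00',
    'memory': None,
    'exclusive': True,
}

cmd = 'sbatch'
for attr, option in options_map.items():
    value = job[attr]
    if value is None:
        continue                       # unset options are skipped
    if value is True:
        cmd += f' --{option}'          # flag-style option, no value
    else:
        cmd += f' --{option}={value}'
cmd += ' hostname'                     # the wrapped command
print(cmd)
# sbatch --job-name=ior_test --nodes=4 --ntasks=16 --time=01:00:00 --exclusive hostname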