diff --git a/jarvis_util/shell/exec.py b/jarvis_util/shell/exec.py index a8c63bf..a9792ce 100644 --- a/jarvis_util/shell/exec.py +++ b/jarvis_util/shell/exec.py @@ -7,7 +7,7 @@ from .local_exec import LocalExec from .pssh_exec import PsshExec from .pssh_exec import SshExec -from .mpi_exec import MpiVersion, MpichExec, OpenMpiExec +from .mpi_exec import MpiVersion, MpichExec, OpenMpiExec, CrayMpichExec from .exec_info import ExecInfo, ExecType, Executable @@ -47,6 +47,8 @@ def __init__(self, cmd, exec_info=None): self.exec_ = MpichExec(cmd, exec_info) elif exec_type == ExecType.OPENMPI: self.exec_ = OpenMpiExec(cmd, exec_info) + elif exec_type == Exectype.CRAY_MPICH: + self.exec_ = CrayMpichExec(cmd, exec_info) self.set_exit_code() self.set_output() diff --git a/jarvis_util/shell/exec_info.py b/jarvis_util/shell/exec_info.py index 4c7a8d2..d60e90a 100644 --- a/jarvis_util/shell/exec_info.py +++ b/jarvis_util/shell/exec_info.py @@ -25,6 +25,7 @@ class ExecType(Enum): INTEL_MPI = 'INTEL_MPI' SLURM = 'SLURM' PBS = 'PBS' + CRAY_MPICH = 'CRAY_MPICH' class ExecInfo: diff --git a/jarvis_util/shell/mpi_exec.py b/jarvis_util/shell/mpi_exec.py index 0e55d8f..22ecb9e 100644 --- a/jarvis_util/shell/mpi_exec.py +++ b/jarvis_util/shell/mpi_exec.py @@ -30,6 +30,8 @@ def __init__(self, exec_info): elif 'Intel(R) MPI Library' in vinfo: # NOTE(llogan): similar to MPICH self.version = ExecType.INTEL_MPI + elif 'mpiexec version' in vinfo: + self.version = ExecType.CRAY_MPICH else: raise Exception(f'Could not identify MPI implementation: {vinfo}') @@ -116,6 +118,45 @@ def mpicmd(self): print(cmd) return cmd +class CrayMpichExec(LocalExec): + """ + This class contains methods for executing a command in parallel + using MPI. + """ + + def __init__(self, cmd, exec_info): + """ + Execute a command using MPI + + :param cmd: A command (string) to execute + :param exec_info: Information needed by MPI + """ + + self.cmd = cmd + self.nprocs = exec_info.nprocs + self.ppn = exec_info.ppn + self.hostfile = exec_info.hostfile + self.mpi_env = exec_info.env + super().__init__(self.mpicmd(), + exec_info.mod(env=exec_info.basic_env)) + + def mpicmd(self): + params = [f'mpirun -n {self.nprocs}'] + if self.ppn is not None: + params.append(f'-ppn {self.ppn}') + if len(self.hostfile): + if self.hostfile.is_subset() or self.hostfile.path is None: + params.append(f'--host {",".join(self.hostfile.hosts)}') + else: + params.append(f'--hostfile {self.hostfile.path}') + params += [f'--env {key}=\"{val}\"' + for key, val in self.mpi_env.items()] + params.append(self.cmd) + cmd = ' '.join(params) + jutil = JutilManager.get_instance() + if jutil.debug_mpi_exec: + print(cmd) + return cmd class MpiExecInfo(ExecInfo): def __init__(self, **kwargs):