Skip to content

Commit

Permalink
Merge pull request #25 from lukemartinlogan/master
Browse files Browse the repository at this point in the history
Add MPI factory
  • Loading branch information
lukemartinlogan authored Sep 16, 2023
2 parents 72da964 + fc55de6 commit 67884ab
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 18 deletions.
21 changes: 15 additions & 6 deletions jarvis_util/shell/exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .local_exec import LocalExec
from .pssh_exec import PsshExec
from .pssh_exec import SshExec
from .mpi_exec import MpiExec
from .mpi_exec import MpiVersion, MpichExec, OpenMpiExec
from .exec_info import ExecInfo, ExecType, Executable


Expand All @@ -26,14 +26,23 @@ def __init__(self, cmd, exec_info=None):
super().__init__()
if exec_info is None:
exec_info = ExecInfo()
if exec_info.exec_type == ExecType.LOCAL:
exec_type = exec_info.exec_type
if exec_type == ExecType.LOCAL:
self.exec_ = LocalExec(cmd, exec_info)
elif exec_info.exec_type == ExecType.SSH:
elif exec_type == ExecType.SSH:
self.exec_ = SshExec(cmd, exec_info)
elif exec_info.exec_type == ExecType.PSSH:
elif exec_type == ExecType.PSSH:
self.exec_ = PsshExec(cmd, exec_info)
elif exec_info.exec_type == ExecType.MPI:
self.exec_ = MpiExec(cmd, exec_info)
elif exec_type == ExecType.MPI:
exec_type = MpiVersion(exec_info).version

if exec_type == ExecType.MPICH:
self.exec_ = MpichExec(cmd, exec_info)
elif exec_type == ExecType.INTEL_MPI:
self.exec_ = MpichExec(cmd, exec_info)
elif exec_type == ExecType.OPENMPI:
self.exec_ = OpenMpiExec(cmd, exec_info)

self.set_exit_code()
self.set_output()

Expand Down
3 changes: 3 additions & 0 deletions jarvis_util/shell/exec_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ class ExecType(Enum):
SSH = 'SSH'
PSSH = 'PSSH'
MPI = 'MPI'
MPICH = 'MPICH'
OPENMPI = 'OPENMPI'
INTEL_MPI = 'INTEL_MPI'


class ExecInfo:
Expand Down
85 changes: 73 additions & 12 deletions jarvis_util/shell/mpi_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,73 @@
from .exec_info import ExecInfo, ExecType


class MpiExec(LocalExec):
class MpiVersion(LocalExec):
"""
Introspect the current MPI implementation from the machine using
mpirun --version
"""

def __init__(self, exec_info):
self.cmd = 'mpirun --version'
super().__init__(self.cmd,
exec_info.mod(env=exec_info.basic_env,
collect_output=True,
hide_output=True))
vinfo = self.stdout
if 'mpich' in vinfo.lower():
self.version = ExecType.MPICH
elif 'Open MPI' in vinfo:
self.version = ExecType.OPENMPI
elif 'Intel(R) MPI Library' in vinfo:
# NOTE(llogan): similar to MPICH
self.version = ExecType.INTEL_MPI
else:
raise Exception(f'Could not identify MPI implementation: {vinfo}')


class OpenMpiExec(LocalExec):
"""
This class contains methods for executing a command in parallel
using MPI.
"""

def __init__(self, cmd, exec_info):
"""
Execute a command using MPI
:param cmd: A command (string) to execute
:param exec_info: Information needed by MPI
"""

self.cmd = cmd
self.nprocs = exec_info.nprocs
self.ppn = exec_info.ppn
self.hostfile = exec_info.hostfile
self.mpi_env = exec_info.env
super().__init__(self.mpicmd(),
exec_info.mod(env=exec_info.basic_env))

def mpicmd(self):
params = [f'mpirun -n {self.nprocs}']
params.append('--oversubscribe')
if self.ppn is not None:
params.append(f'-ppn {self.ppn}')
if len(self.hostfile):
if self.hostfile.is_subset() or self.hostfile.path is None:
params.append(f'--host {",".join(self.hostfile.hosts)}')
else:
params.append(f'--hostfile {self.hostfile.path}')
params += [f'-x {key}=\"{val}\"'
for key, val in self.mpi_env.items()]
params.append(self.cmd)
cmd = ' '.join(params)
jutil = JutilManager.get_instance()
if jutil.debug_mpi_exec:
print(cmd)
return cmd


class MpichExec(LocalExec):
"""
This class contains methods for executing a command in parallel
using MPI.
Expand All @@ -33,23 +99,18 @@ def __init__(self, cmd, exec_info):
exec_info.mod(env=exec_info.basic_env))

def mpicmd(self):
# NOTE(llogan):
# Use -x instead of -genv for openmpi
# Use --oversubscribe for openmpi
params = [f"mpirun -n {self.nprocs}"]
# params.append('--oversubscribe')
params = [f'mpirun -n {self.nprocs}']
if self.ppn is not None:
params.append(f"-ppn {self.ppn}")
params.append(f'-ppn {self.ppn}')
if len(self.hostfile):
if self.hostfile.is_subset() or self.hostfile.path is None:
params.append(f"--host {','.join(self.hostfile.hosts)}")
params.append(f'--host {",".join(self.hostfile.hosts)}')
else:
params.append(f"--hostfile {self.hostfile.path}")
params += [f"-genv {key}=\"{val}\""
params.append(f'--hostfile {self.hostfile.path}')
params += [f'-genv {key}=\"{val}\"'
for key, val in self.mpi_env.items()]
# params += [f"-x {key}={val}" for key, val in self.mpi_env.items()]
params.append(self.cmd)
cmd = " ".join(params)
cmd = ' '.join(params)
jutil = JutilManager.get_instance()
if jutil.debug_mpi_exec:
print(cmd)
Expand Down
14 changes: 14 additions & 0 deletions test/unit/test_mpi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from jarvis_util.util.argparse import ArgParse
from jarvis_util.shell.exec import Exec
from jarvis_util.shell.local_exec import LocalExecInfo
from jarvis_util.shell.mpi_exec import MpiVersion
from jarvis_util.util.size_conv import SizeConv
import pathlib
import itertools
from unittest import TestCase


class TestSystemInfo(TestCase):
def test_mpi(self):
info = MpiVersion(LocalExecInfo())
print(f'MPI VERSION: {info.version}')

0 comments on commit 67884ab

Please sign in to comment.