Skip to content

Commit

Permalink
Merge pull request #676 from OP2/JDBetteridge/clean_comms
Browse files Browse the repository at this point in the history
JDBetteridge/clean comms
  • Loading branch information
ksagiyam authored Nov 29, 2022
2 parents 7a3c68f + 4c439ea commit 7aca5e5
Show file tree
Hide file tree
Showing 13 changed files with 435 additions and 192 deletions.
77 changes: 19 additions & 58 deletions pyop2/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@
from packaging.version import Version, InvalidVersion


from pyop2.mpi import MPI, collective, COMM_WORLD
from pyop2.mpi import dup_comm, get_compilation_comm, set_compilation_comm
from pyop2 import mpi
from pyop2.configuration import configuration
from pyop2.logger import warning, debug, progress, INFO
from pyop2.exceptions import CompilationError
Expand All @@ -59,7 +58,7 @@ def _check_hashes(x, y, datatype):
return False


_check_op = MPI.Op.Create(_check_hashes, commute=True)
_check_op = mpi.MPI.Op.Create(_check_hashes, commute=True)
_compiler = None


Expand Down Expand Up @@ -148,53 +147,6 @@ def sniff_compiler(exe):
return compiler


@collective
def compilation_comm(comm):
"""Get a communicator for compilation.
:arg comm: The input communicator.
:returns: A communicator used for compilation (may be smaller)
"""
# Should we try and do node-local compilation?
if not configuration["node_local_compilation"]:
return comm
retcomm = get_compilation_comm(comm)
if retcomm is not None:
debug("Found existing compilation communicator")
return retcomm
if MPI.VERSION >= 3:
debug("Creating compilation communicator using MPI_Split_type")
retcomm = comm.Split_type(MPI.COMM_TYPE_SHARED)
debug("Finished creating compilation communicator using MPI_Split_type")
set_compilation_comm(comm, retcomm)
return retcomm
debug("Creating compilation communicator using MPI_Split + filesystem")
import tempfile
if comm.rank == 0:
if not os.path.exists(configuration["cache_dir"]):
os.makedirs(configuration["cache_dir"], exist_ok=True)
tmpname = tempfile.mkdtemp(prefix="rank-determination-",
dir=configuration["cache_dir"])
else:
tmpname = None
tmpname = comm.bcast(tmpname, root=0)
if tmpname is None:
raise CompilationError("Cannot determine sharedness of filesystem")
# Touch file
debug("Made tmpdir %s" % tmpname)
with open(os.path.join(tmpname, str(comm.rank)), "wb"):
pass
comm.barrier()
import glob
ranks = sorted(int(os.path.basename(name))
for name in glob.glob("%s/[0-9]*" % tmpname))
debug("Creating compilation communicator using filesystem colors")
retcomm = comm.Split(color=min(ranks), key=comm.rank)
debug("Finished creating compilation communicator using filesystem colors")
set_compilation_comm(comm, retcomm)
return retcomm


class Compiler(ABC):
"""A compiler for shared libraries.
Expand All @@ -210,7 +162,7 @@ class Compiler(ABC):
:arg cpp: Should we try and use the C++ compiler instead of the C
compiler?.
:kwarg comm: Optional communicator to compile the code on
(defaults to COMM_WORLD).
(defaults to pyop2.mpi.COMM_WORLD).
"""
_name = "unknown"

Expand All @@ -226,16 +178,24 @@ class Compiler(ABC):
_debugflags = ()

def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, comm=None):
# Get compiler version ASAP since it is used in __repr__
self.sniff_compiler_version()

self._extra_compiler_flags = tuple(extra_compiler_flags)
self._extra_linker_flags = tuple(extra_linker_flags)

self._cpp = cpp
self._debug = configuration["debug"]

# Ensure that this is an internal communicator.
comm = dup_comm(comm or COMM_WORLD)
self.comm = compilation_comm(comm)
self.sniff_compiler_version()
# Compilation communicators are reference counted on the PyOP2 comm
self.pcomm = mpi.internal_comm(comm)
self.comm = mpi.compilation_comm(self.pcomm)

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)
if hasattr(self, "pcomm"):
mpi.decref(self.pcomm)

def __repr__(self):
return f"<{self._name} compiler, version {self.version or 'unknown'}>"
Expand Down Expand Up @@ -313,7 +273,7 @@ def expandWl(ldflags):
else:
yield flag

@collective
@mpi.collective
def get_so(self, jitmodule, extension):
"""Build a shared library and load it
Expand Down Expand Up @@ -591,7 +551,7 @@ class AnonymousCompiler(Compiler):
_name = "Unknown"


@collective
@mpi.collective
def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(),
argtypes=None, restype=None, comm=None):
"""Build a shared library and return a function pointer from it.
Expand All @@ -608,7 +568,7 @@ def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(),
:arg restype: The return type of the function (optional, pass
``None`` for ``void``).
:kwarg comm: Optional communicator to compile the code on (only
rank 0 compiles code) (defaults to COMM_WORLD).
rank 0 compiles code) (defaults to pyop2.mpi.COMM_WORLD).
"""
from pyop2.global_kernel import GlobalKernel

Expand Down Expand Up @@ -639,6 +599,7 @@ def __init__(self, code, argtypes):
exe = configuration["cc"] or "mpicc"
compiler = sniff_compiler(exe)
dll = compiler(cppargs, ldargs, cpp=cpp, comm=comm).get_so(code, extension)

if isinstance(jitmodule, GlobalKernel):
_add_profiling_events(dll, code.local_kernel.events)

Expand Down
1 change: 1 addition & 0 deletions pyop2/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
handler = logging.StreamHandler()
logger.addHandler(handler)


debug = logger.debug
info = logger.info
warning = logger.warning
Expand Down
Loading

0 comments on commit 7aca5e5

Please sign in to comment.