Skip to content

Commit

Permalink
Tidy code, address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Nov 28, 2022
1 parent ba86a16 commit 4c439ea
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 40 deletions.
72 changes: 36 additions & 36 deletions pyop2/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,18 @@
from petsc4py import PETSc
from mpi4py import MPI # noqa
import atexit
import gc
import glob
import os
import tempfile

from pyop2.configuration import configuration
from pyop2.exceptions import CompilationError
from pyop2.logger import warning, debug, logger, DEBUG
from pyop2.utils import trim


__all__ = ("COMM_WORLD", "COMM_SELF", "MPI", "internal_comm", "is_pyop2_comm", "incref", "decref", "PyOP2Comm")
__all__ = ("COMM_WORLD", "COMM_SELF", "MPI", "internal_comm", "is_pyop2_comm", "incref", "decref", "temp_internal_comm")

# These are user-level communicators, we never send any messages on
# them inside PyOP2.
Expand Down Expand Up @@ -113,7 +117,7 @@ def delcomm_outer(comm, keyval, icomm):
:arg comm: Outer communicator.
:arg keyval: The MPI keyval, should be ``innercomm_keyval``.
:arg icomm: The inner communicator, should have a reference to
``comm`.
``comm``.
"""
if keyval != innercomm_keyval:
raise ValueError("Unexpected keyval")
Expand Down Expand Up @@ -143,17 +147,17 @@ def delcomm_outer(comm, keyval, icomm):


def is_pyop2_comm(comm):
"""Returns `True` if `comm` is a PyOP2 communicator,
"""Returns ``True`` if ``comm`` is a PyOP2 communicator,
False if `comm` another communicator.
Raises exception if `comm` is not a communicator.
Raises exception if ``comm`` is not a communicator.
:arg comm: Communicator to query
"""
global PYOP2_FINALIZED
if isinstance(comm, PETSc.Comm):
ispyop2comm = False
elif comm == MPI.COMM_NULL:
if PYOP2_FINALIZED is False:
if not PYOP2_FINALIZED:
raise ValueError("Communicator passed to is_pyop2_comm() is COMM_NULL")
else:
ispyop2comm = True
Expand Down Expand Up @@ -182,51 +186,54 @@ def pyop2_comm_status():
return status_string


class PyOP2Comm:
class temp_internal_comm:
""" Use a PyOP2 internal communicator and
increment and decrement the internal comm.
:arg comm: Any communicator
"""
def __init__(self, comm):
self.user_comm = comm
self.internal_comm = None
self.internal_comm = internal_comm(self.user_comm)

def __del__(self):
decref(self.internal_comm)

def __enter__(self):
""" Returns an internal comm tat will be safely decref'd
when leaving the context manager
""" Returns an internal comm that will be safely decref'd
when the context manager is destroyed
:returns pyop2_comm: A PyOP2 internal communicator
"""
self.internal_comm = internal_comm(self.user_comm)
return self.internal_comm

def __exit__(self, exc_type, exc_value, traceback):
decref(self.internal_comm)
self.internal_comm = None
pass


def internal_comm(comm):
""" Creates an internal comm from the user comm
""" Creates an internal comm from the user comm.
If comm is None, create an internal communicator from COMM_WORLD
:arg comm: A communicator or None
:returns pyop2_comm: A PyOP2 internal communicator
"""
# Parse inputs
if comm is None:
# None will be the default when creating most objects
pyop2_comm = dup_comm(COMM_WORLD)
elif is_pyop2_comm(comm):
# Increase the reference count and return same comm if
# already an internal communicator
incref(comm)
pyop2_comm = comm
comm = COMM_WORLD
elif isinstance(comm, PETSc.Comm):
# Convert PETSc.Comm to mpi4py.MPI.Comm
pyop2_comm = dup_comm(comm.tompi4py())
elif comm == MPI.COMM_NULL:
# Ensure comm is not the NULL communicator
comm = comm.tompi4py()

# Check for invalid inputs
if comm == MPI.COMM_NULL:
raise ValueError("MPI_COMM_NULL passed to internal_comm()")
elif not isinstance(comm, MPI.Comm):
raise ValueError("Don't know how to dup a %r" % type(comm))

# Handle a valid input
if is_pyop2_comm(comm):
incref(comm)
pyop2_comm = comm
else:
pyop2_comm = dup_comm(comm)
return pyop2_comm
Expand All @@ -243,7 +250,6 @@ def incref(comm):
def decref(comm):
""" Decrement communicator reference count
"""
global PYOP2_FINALIZED
if not PYOP2_FINALIZED:
assert is_pyop2_comm(comm)
refcount = comm.Get_attr(refcount_keyval)
Expand Down Expand Up @@ -287,7 +293,7 @@ def dup_comm(comm_in):
@collective
def create_split_comm(comm):
""" Create a split communicator based on either shared memory access
if using MPI >= 3, or shared local disk access if using MPI >= 3.
if using MPI >= 3, or shared local disk access if using MPI <= 3.
Used internally for creating compilation communicators
:arg comm: A communicator to split
Expand All @@ -300,7 +306,6 @@ def create_split_comm(comm):
debug("Finished creating compilation communicator using MPI_Split_type")
else:
debug("Creating compilation communicator using MPI_Split + filesystem")
import tempfile
if comm.rank == 0:
if not os.path.exists(configuration["cache_dir"]):
os.makedirs(configuration["cache_dir"], exist_ok=True)
Expand All @@ -316,7 +321,6 @@ def create_split_comm(comm):
with open(os.path.join(tmpname, str(comm.rank)), "wb"):
pass
comm.barrier()
import glob
ranks = sorted(int(os.path.basename(name))
for name in glob.glob("%s/[0-9]*" % tmpname))
debug("Creating compilation communicator using filesystem colors")
Expand All @@ -335,8 +339,8 @@ def get_compilation_comm(comm):


def set_compilation_comm(comm, comp_comm):
"""Stash the compilation communicator (`comp_comm`) on the
PyOP2 communicator `comm`
"""Stash the compilation communicator (``comp_comm``) on the
PyOP2 communicator ``comm``
:arg comm: A PyOP2 Communicator
:arg comp_comm: The compilation communicator
Expand Down Expand Up @@ -435,17 +439,13 @@ def _free_comms():
debug("PyOP2 Finalizing")
# Collect garbage as it may hold on to communicator references
debug("Calling gc.collect()")
import gc
gc.collect()
debug(pyop2_comm_status())
debug(f"Freeing comms in list (length {len(dupped_comms)})")
while dupped_comms:
c = dupped_comms[-1]
if is_pyop2_comm(c):
refcount = c.Get_attr(refcount_keyval)
debug(f"Freeing {c.name}, which has refcount {refcount[0]}")
else:
debug("Freeing non PyOP2 comm in `_free_comms()`")
refcount = c.Get_attr(refcount_keyval)
debug(f"Freeing {c.name}, which has refcount {refcount[0]}")
free_comm(c)
for kv in [refcount_keyval,
innercomm_keyval,
Expand All @@ -459,7 +459,7 @@ def _free_comms():
def hash_comm(comm):
"""Return a hashable identifier for a communicator."""
if not is_pyop2_comm(comm):
ValueError("`comm` passed to `hash_comm()` must be a PyOP2 communicator")
raise ValueError("`comm` passed to `hash_comm()` must be a PyOP2 communicator")
# `comm` must be a PyOP2 communicator so we can use its id()
# as the hash and this is stable between invocations.
return id(comm)
Expand Down
3 changes: 2 additions & 1 deletion pyop2/types/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ def __init__(self, arg, dims=None):
return
self._dsets = arg
try:
# Try/except may not be necessary, someone needs to think about this...
# Try to choose the comm to be the same as the first set
# of the MixedDataSet
comm = self._process_args(arg, dims)[0][0].comm
except AttributeError:
comm = None
Expand Down
6 changes: 3 additions & 3 deletions pyop2/types/mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,7 @@ def mult(self, mat, x, y):
a[0] = x.array_r
else:
x.array_r
with mpi.PyOP2Comm(x.comm) as comm:
with mpi.temp_internal_comm(x.comm) as comm:
comm.bcast(a)
return y.scale(a)
else:
Expand All @@ -1130,7 +1130,7 @@ def multTranspose(self, mat, x, y):
a[0] = x.array_r
else:
x.array_r
with mpi.PyOP2Comm(x.comm) as comm:
with mpi.temp_internal_comm(x.comm) as comm:
comm.bcast(a)
y.scale(a)
else:
Expand All @@ -1155,7 +1155,7 @@ def multTransposeAdd(self, mat, x, y, z):
a[0] = x.array_r
else:
x.array_r
with mpi.PyOP2Comm(x.comm) as comm:
with mpi.temp_internal_comm(x.comm) as comm:
comm.bcast(a)
if y == z:
# Last two arguments are aliased.
Expand Down

0 comments on commit 4c439ea

Please sign in to comment.