From 1f0a740d596ca6359ca98bdad3079208e72587be Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Fri, 30 Sep 2022 13:10:08 +0100 Subject: [PATCH 01/17] Replaced tompi4py() with proper comm_dup() calls --- pyop2/types/mat.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pyop2/types/mat.py b/pyop2/types/mat.py index de89b1421..87e79c9e6 100644 --- a/pyop2/types/mat.py +++ b/pyop2/types/mat.py @@ -1094,7 +1094,10 @@ def mult(self, mat, x, y): a[0] = x.array_r else: x.array_r - x.comm.tompi4py().bcast(a) + + comm = mpi.dup_comm(x.comm) + comm.bcast(a) + mpi.free_comm(comm) return y.scale(a) else: return v.pointwiseMult(x, y) @@ -1110,7 +1113,9 @@ def multTranspose(self, mat, x, y): a[0] = x.array_r else: x.array_r - x.comm.tompi4py().bcast(a) + comm = mpi.dup_comm(x.comm) + comm.bcast(a) + mpi.free_comm(comm) y.scale(a) else: v.pointwiseMult(x, y) @@ -1134,7 +1139,9 @@ def multTransposeAdd(self, mat, x, y, z): a[0] = x.array_r else: x.array_r - x.comm.tompi4py().bcast(a) + comm = mpi.dup_comm(x.comm) + comm.bcast(a) + mpi.free_comm(comm) if y == z: # Last two arguments are aliased. tmp = y.duplicate() From 65fd5062d82a7a7ce79b8204991a9937215f56ae Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Sun, 2 Oct 2022 22:29:23 +0100 Subject: [PATCH 02/17] WIP: Tests passing, but not all comms freed --- pyop2/caching.py | 5 +- pyop2/compilation.py | 85 +++------ pyop2/logger.py | 4 + pyop2/mpi.py | 367 +++++++++++++++++++++++++++++------- pyop2/op2.py | 5 +- pyop2/parloop.py | 10 +- pyop2/sparsity.pyx | 2 +- pyop2/types/dat.py | 12 +- pyop2/types/data_carrier.py | 1 + pyop2/types/dataset.py | 28 ++- pyop2/types/glob.py | 36 ++-- pyop2/types/map.py | 14 +- pyop2/types/mat.py | 73 ++++--- pyop2/types/set.py | 31 ++- test/unit/test_caching.py | 8 +- 15 files changed, 489 insertions(+), 192 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 24a3f5513..28ee74a9a 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -41,7 +41,7 @@ import cachetools from pyop2.configuration import configuration -from pyop2.mpi import hash_comm +from pyop2.mpi import hash_comm, is_pyop2_comm from pyop2.utils import cached_property @@ -274,6 +274,9 @@ def wrapper(*args, **kwargs): if collective: comm, disk_key = key(*args, **kwargs) disk_key = _as_hexdigest(disk_key) + # ~ k = id(comm), disk_key + # ~ if not is_pyop2_comm(comm): + # ~ import pytest; pytest.set_trace() k = hash_comm(comm), disk_key else: k = _as_hexdigest(key(*args, **kwargs)) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index ecca43187..831c775e8 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -44,13 +44,12 @@ from packaging.version import Version, InvalidVersion -from pyop2.mpi import MPI, collective, COMM_WORLD -from pyop2.mpi import dup_comm, get_compilation_comm, set_compilation_comm +from pyop2 import mpi from pyop2.configuration import configuration from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError from petsc4py import PETSc - +from pyop2.logger import debug def _check_hashes(x, y, datatype): """MPI reduction op to check if code hashes differ across ranks.""" @@ -59,7 +58,7 @@ def _check_hashes(x, y, datatype): return False -_check_op = MPI.Op.Create(_check_hashes, commute=True) +_check_op = mpi.MPI.Op.Create(_check_hashes, commute=True) _compiler = None @@ -148,53 +147,6 @@ def sniff_compiler(exe): return compiler -@collective -def compilation_comm(comm): - """Get a communicator for compilation. - - :arg comm: The input communicator. - :returns: A communicator used for compilation (may be smaller) - """ - # Should we try and do node-local compilation? - if not configuration["node_local_compilation"]: - return comm - retcomm = get_compilation_comm(comm) - if retcomm is not None: - debug("Found existing compilation communicator") - return retcomm - if MPI.VERSION >= 3: - debug("Creating compilation communicator using MPI_Split_type") - retcomm = comm.Split_type(MPI.COMM_TYPE_SHARED) - debug("Finished creating compilation communicator using MPI_Split_type") - set_compilation_comm(comm, retcomm) - return retcomm - debug("Creating compilation communicator using MPI_Split + filesystem") - import tempfile - if comm.rank == 0: - if not os.path.exists(configuration["cache_dir"]): - os.makedirs(configuration["cache_dir"], exist_ok=True) - tmpname = tempfile.mkdtemp(prefix="rank-determination-", - dir=configuration["cache_dir"]) - else: - tmpname = None - tmpname = comm.bcast(tmpname, root=0) - if tmpname is None: - raise CompilationError("Cannot determine sharedness of filesystem") - # Touch file - debug("Made tmpdir %s" % tmpname) - with open(os.path.join(tmpname, str(comm.rank)), "wb"): - pass - comm.barrier() - import glob - ranks = sorted(int(os.path.basename(name)) - for name in glob.glob("%s/[0-9]*" % tmpname)) - debug("Creating compilation communicator using filesystem colors") - retcomm = comm.Split(color=min(ranks), key=comm.rank) - debug("Finished creating compilation communicator using filesystem colors") - set_compilation_comm(comm, retcomm) - return retcomm - - class Compiler(ABC): """A compiler for shared libraries. @@ -210,7 +162,7 @@ class Compiler(ABC): :arg cpp: Should we try and use the C++ compiler instead of the C compiler?. :kwarg comm: Optional communicator to compile the code on - (defaults to COMM_WORLD). + (defaults to pyop2.mpi.COMM_WORLD). """ _name = "unknown" @@ -226,16 +178,27 @@ class Compiler(ABC): _debugflags = () def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, comm=None): + self.sniff_compiler_version() self._extra_compiler_flags = tuple(extra_compiler_flags) self._extra_linker_flags = tuple(extra_linker_flags) self._cpp = cpp self._debug = configuration["debug"] - # Ensure that this is an internal communicator. - comm = dup_comm(comm or COMM_WORLD) - self.comm = compilation_comm(comm) + # Compilation communicators are reference counted on the PyOP2 comm + self.pcomm = mpi.internal_comm(comm) + self.comm = mpi.compilation_comm(self.pcomm) self.sniff_compiler_version() + debug(f"INIT {self.__class__} and assign {self.comm.name}") + debug(f"INIT {self.__class__} and assign {self.pcomm.name}") + + def __del__(self): + if hasattr(self, "comm"): + debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") + mpi.decref(self.comm) + if hasattr(self, "pcomm"): + debug(f"DELETE {self.__class__} and removing reference to {self.pcomm.name}") + mpi.decref(self.pcomm) def __repr__(self): return f"<{self._name} compiler, version {self.version or 'unknown'}>" @@ -313,7 +276,7 @@ def expandWl(ldflags): else: yield flag - @collective + @mpi.collective def get_so(self, jitmodule, extension): """Build a shared library and load it @@ -445,6 +408,8 @@ def get_so(self, jitmodule, extension): # Atomically ensure soname exists os.rename(tmpname, soname) # Wait for compilation to complete + if self.comm == mpi.MPI.COMM_NULL: + import pytest; pytest.set_trace() self.comm.barrier() # Load resulting library return ctypes.CDLL(soname) @@ -591,7 +556,7 @@ class AnonymousCompiler(Compiler): _name = "Unknown" -@collective +@mpi.collective def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), argtypes=None, restype=None, comm=None): """Build a shared library and return a function pointer from it. @@ -608,7 +573,7 @@ def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), :arg restype: The return type of the function (optional, pass ``None`` for ``void``). :kwarg comm: Optional communicator to compile the code on (only - rank 0 compiles code) (defaults to COMM_WORLD). + rank 0 compiles code) (defaults to pyop2.mpi.COMM_WORLD). """ from pyop2.global_kernel import GlobalKernel @@ -638,7 +603,9 @@ def __init__(self, code, argtypes): else: exe = configuration["cc"] or "mpicc" compiler = sniff_compiler(exe) - dll = compiler(cppargs, ldargs, cpp=cpp, comm=comm).get_so(code, extension) + x = compiler(cppargs, ldargs, cpp=cpp, comm=comm) + dll = x.get_so(code, extension) + del x if isinstance(jitmodule, GlobalKernel): _add_profiling_events(dll, code.local_kernel.events) diff --git a/pyop2/logger.py b/pyop2/logger.py index fb6532746..833eeb8c2 100644 --- a/pyop2/logger.py +++ b/pyop2/logger.py @@ -40,6 +40,10 @@ handler = logging.StreamHandler() logger.addHandler(handler) +fhandler = logging.FileHandler('pyop2.log') +logger.addHandler(fhandler) + + debug = logger.debug info = logger.info warning = logger.warning diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 1ee16c11d..cb48efc60 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -37,16 +37,23 @@ from petsc4py import PETSc from mpi4py import MPI # noqa import atexit +import inspect # remove later +from pyop2.configuration import configuration +from pyop2.logger import warning, debug, progress, INFO from pyop2.utils import trim -__all__ = ("COMM_WORLD", "COMM_SELF", "MPI", "dup_comm") +__all__ = ("COMM_WORLD", "COMM_SELF", "MPI", "internal_comm", "is_pyop2_comm", "incref", "decref") # These are user-level communicators, we never send any messages on # them inside PyOP2. COMM_WORLD = PETSc.COMM_WORLD.tompi4py() +COMM_WORLD.Set_name("PYOP2_COMM_WORLD") COMM_SELF = PETSc.COMM_SELF.tompi4py() +COMM_SELF.Set_name("PYOP2_COMM_SELF") + +PYOP2_FINALIZED = False # Exposition: # @@ -90,6 +97,15 @@ # outstanding duplicated communicators. +def collective(fn): + extra = trim(""" + This function is logically collective over MPI ranks, it is an + error to call it on fewer than all the ranks in MPI communicator. + """) + fn.__doc__ = "%s\n\n%s" % (trim(fn.__doc__), extra) if fn.__doc__ else extra + return fn + + def delcomm_outer(comm, keyval, icomm): """Deleter for internal communicator, removes reference to outer comm. @@ -118,52 +134,204 @@ def delcomm_outer(comm, keyval, icomm): # Outer communicator attribute (attaches user comm to inner communicator) outercomm_keyval = MPI.Comm.Create_keyval() +# Comm used for compilation, stashed on the internal communicator +compilationcomm_keyval = MPI.Comm.Create_keyval() + # List of internal communicators, must be freed at exit. dupped_comms = [] -def dup_comm(comm_in=None): - """Given a communicator return a communicator for internal use. +class FriendlyCommNull: + def __init__(self): + self.name = 'PYOP2_FRIENDLY_COMM_NULL' - :arg comm_in: Communicator to duplicate. If not provided, - defaults to COMM_WORLD. + def Get_attr(self, keyval): + return [1] - :returns: An mpi4py communicator.""" - if comm_in is None: - comm_in = COMM_WORLD - if isinstance(comm_in, PETSc.Comm): - comm_in = comm_in.tompi4py() - elif not isinstance(comm_in, MPI.Comm): - raise ValueError("Don't know how to dup a %r" % type(comm_in)) - if comm_in == MPI.COMM_NULL: - return comm_in - refcount = comm_in.Get_attr(refcount_keyval) - if refcount is not None: - # Passed an existing PyOP2 comm, return it - comm_out = comm_in - refcount[0] += 1 + def Free(self): + pass + + +def is_pyop2_comm(comm): + """Returns `True` if `comm` is a PyOP2 communicator, + False if `comm` another communicator. + Raises exception if `comm` is not a communicator. + + :arg comm: Communicator to query + """ + global PYOP2_FINALIZED + if isinstance(comm, PETSc.Comm): + ispyop2comm = False + elif comm == MPI.COMM_NULL: + if PYOP2_FINALIZED == False: + # ~ import pytest; pytest.set_trace() + # ~ raise ValueError("COMM_NULL") + ispyop2comm = True + else: + ispyop2comm = True + elif isinstance(comm, MPI.Comm): + ispyop2comm = bool(comm.Get_attr(refcount_keyval)) else: - # Check if communicator has an embedded PyOP2 comm. - comm_out = comm_in.Get_attr(innercomm_keyval) - if comm_out is None: - # Haven't seen this comm before, duplicate it. - comm_out = comm_in.Dup() - comm_in.Set_attr(innercomm_keyval, comm_out) - comm_out.Set_attr(outercomm_keyval, comm_in) - # Refcount - comm_out.Set_attr(refcount_keyval, [1]) - # Remember we need to destroy it. - dupped_comms.append(comm_out) + raise ValueError("Argument passed to is_pyop2_comm() is not a recognised comm type") + return ispyop2comm + + +def pyop2_comm_status(): + """ Prints the reference counts for all comms PyOP2 has duplicated + """ + print('PYOP2 Communicator reference counts:') + print('| Communicator name | Count |') + print('==================================================') + for comm in dupped_comms: + if comm == MPI.COMM_NULL: + null = 'COMM_NULL' + print(f'| {null:39}| {0:5d} |') else: - refcount = comm_out.Get_attr(refcount_keyval) + refcount = comm.Get_attr(refcount_keyval)[0] if refcount is None: - raise ValueError("Inner comm without a refcount") - refcount[0] += 1 + refcount = -999 + print(f'| {comm.name:39}| {refcount:5d} |') + + +class PyOP2Comm: + """ Suitable for using a PyOP2 internal communicator suitably + incrementing and decrementing the comm. + """ + def __init__(self, comm): + self.comm = comm + self._comm = None + + def __enter__(self): + self._comm = internal_comm(self.comm) + return self._comm + + def __exit__(self, exc_type, exc_value, traceback): + decref(self._comm) + self._comm = None + + +def internal_comm(comm): + """ Creates an internal comm from the comm passed in + This happens on nearly every PyOP2 object so this avoids unnecessary + repetition. + :arg comm: A communicator or None + + :returns pyop2_comm: A PyOP2 internal communicator + """ + if comm is None: + # None will be the default when creating most objects + pyop2_comm = dup_comm(COMM_WORLD) + elif is_pyop2_comm(comm): + # Increase the reference count and return same comm if + # already an internal communicator + incref(comm) + pyop2_comm = comm + elif isinstance(comm, PETSc.Comm): + # Convert PETSc.Comm to mpi4py.MPI.Comm + comm = dup_comm(comm.tompi4py()) + pyop2_comm.Set_name(f"PYOP2_{comm.name or id(comm)}") + elif comm == MPI.COMM_NULL: + # Ensure comm is not the NULL communicator + raise ValueError("MPI_COMM_NULL passed to internal_comm()") + elif not isinstance(comm, MPI.Comm): + # If it is not an MPI.Comm raise error + raise ValueError("Don't know how to dup a %r" % type(comm)) + else: + pyop2_comm = dup_comm(comm) + return pyop2_comm + + +def incref(comm): + """ Increment communicator reference count + """ + assert is_pyop2_comm(comm) + refcount = comm.Get_attr(refcount_keyval) + refcount[0] += 1 + debug(f'{comm.name} INCREF to {refcount[0]}') + + +def decref(comm): + """ Decrement communicator reference count + """ + if comm == MPI.COMM_NULL: + comm = FriendlyCommNull() + assert is_pyop2_comm(comm) + # ~ if not PYOP2_FINALIZED: + refcount = comm.Get_attr(refcount_keyval) + refcount[0] -= 1 + debug(f'{comm.name} DECREF to {refcount[0]}') + if refcount[0] == 0: + dupped_comms.remove(comm) + debug(f'Freeing {comm.name}') + free_comm(comm) + + +def dup_comm(comm_in): + """Given a communicator return a communicator for internal use. + + :arg comm_in: Communicator to duplicate + + :returns: An mpi4py communicator.""" + assert not is_pyop2_comm(comm_in) + + # Check if communicator has an embedded PyOP2 comm. + comm_out = comm_in.Get_attr(innercomm_keyval) + if comm_out is None: + # Haven't seen this comm before, duplicate it. + comm_out = comm_in.Dup() + comm_in.Set_attr(innercomm_keyval, comm_out) + comm_out.Set_attr(outercomm_keyval, comm_in) + # Name + comm_out.Set_name(f"{comm_in.name or id(comm_in)}_DUP") + # Refcount + comm_out.Set_attr(refcount_keyval, [0]) + incref(comm_out) + # Remember we need to destroy it. + dupped_comms.append(comm_out) + elif is_pyop2_comm(comm_out): + # Inner comm is a PyOP2 comm, return it + incref(comm_out) + else: + raise ValueError("Inner comm is not a PyOP2 comm") return comm_out -# Comm used for compilation, stashed on the internal communicator -compilationcomm_keyval = MPI.Comm.Create_keyval() +@collective +def create_split_comm(comm): + if MPI.VERSION >= 3: + debug("Creating compilation communicator using MPI_Split_type") + split_comm = comm.Split_type(MPI.COMM_TYPE_SHARED) + debug("Finished creating compilation communicator using MPI_Split_type") + else: + debug("Creating compilation communicator using MPI_Split + filesystem") + import tempfile + if comm.rank == 0: + if not os.path.exists(configuration["cache_dir"]): + os.makedirs(configuration["cache_dir"], exist_ok=True) + tmpname = tempfile.mkdtemp(prefix="rank-determination-", + dir=configuration["cache_dir"]) + else: + tmpname = None + tmpname = comm.bcast(tmpname, root=0) + if tmpname is None: + raise CompilationError("Cannot determine sharedness of filesystem") + # Touch file + debug("Made tmpdir %s" % tmpname) + with open(os.path.join(tmpname, str(comm.rank)), "wb"): + pass + comm.barrier() + import glob + ranks = sorted(int(os.path.basename(name)) + for name in glob.glob("%s/[0-9]*" % tmpname)) + debug("Creating compilation communicator using filesystem colors") + split_comm = comm.Split(color=min(ranks), key=comm.rank) + debug("Finished creating compilation communicator using filesystem colors") + # Name + split_comm.Set_name(f"{comm.name or id(comm)}_COMPILATION") + # Refcount + split_comm.Set_attr(refcount_keyval, [0]) + incref(split_comm) + return split_comm def get_compilation_comm(comm): @@ -171,10 +339,59 @@ def get_compilation_comm(comm): def set_compilation_comm(comm, inner): - comm.Set_attr(compilationcomm_keyval, inner) + """Set the compilation communicator. + + :arg comm: A PyOP2 Communicator + :arg inner: The compilation communicator + """ + # Ensure `comm` is a PyOP2 comm + if not is_pyop2_comm(comm): + raise ValueError("Compilation communicator must be stashed on a PyOP2 comm") + + # Check if the compilation communicator is already set + old_inner = comm.Get_attr(compilationcomm_keyval) + if old_inner is not None: + if is_pyop2_comm(old_inner): + raise ValueError("Compilation communicator is not a PyOP2 comm, something is very broken!") + else: + decref(old_inner) + if not is_pyop2_comm(inner): + raise ValueError( + "Communicator used for compilation communicator must be a PyOP2 communicator.\n" + "Use pyop2.mpi.dup_comm() to create a PyOP2 comm from an existing comm.") + else: + # Stash `inner` as an attribute on `comm` + comm.Set_attr(compilationcomm_keyval, inner) + + +@collective +def compilation_comm(comm): + """Get a communicator for compilation. -def free_comm(comm, remove=True): + :arg comm: The input communicator, must be a PyOP2 comm. + :returns: A communicator used for compilation (may be smaller) + """ + if not is_pyop2_comm(comm): + raise ValueError("Compilation communicator is not a PyOP2 comm") + # Should we try and do node-local compilation? + if configuration["node_local_compilation"]: + retcomm = get_compilation_comm(comm) + if retcomm is not None: + debug("Found existing compilation communicator") + else: + retcomm = create_split_comm(comm) + set_compilation_comm(comm, retcomm) + # Add to list of known duplicated comms + debug(f"Appending compiler comm {retcomm.name} to list of comms") + dupped_comms.append(retcomm) + else: + retcomm = comm + incref(retcomm) + return retcomm + + +def free_comm(comm): """Free an internal communicator. :arg comm: The communicator to free. @@ -183,21 +400,18 @@ def free_comm(comm, remove=True): This only actually calls MPI_Comm_free once the refcount drops to zero. """ - if comm == MPI.COMM_NULL: - return - refcount = comm.Get_attr(refcount_keyval) - if refcount is None: - # Not a PyOP2 communicator, check for an embedded comm. - comm = comm.Get_attr(innercomm_keyval) - if comm is None: - raise ValueError("Trying to destroy communicator not known to PyOP2") - refcount = comm.Get_attr(refcount_keyval) - if refcount is None: - raise ValueError("Inner comm without a refcount") + if comm != MPI.COMM_NULL: + assert is_pyop2_comm(comm) + # ~ if is_pyop2_comm(comm): + # ~ # Not a PyOP2 communicator, check for an embedded comm. + # ~ comm = comm.Get_attr(innercomm_keyval) + # ~ if comm is None: + # ~ raise ValueError("Trying to destroy communicator not known to PyOP2") + # ~ if not is_pyop2_comm(comm): + # ~ raise ValueError("Inner comm is not a PyOP2 comm") + + # ~ decref(comm) - refcount[0] -= 1 - - if refcount[0] == 0: ocomm = comm.Get_attr(outercomm_keyval) if ocomm is not None: icomm = ocomm.Get_attr(innercomm_keyval) @@ -206,23 +420,43 @@ def free_comm(comm, remove=True): else: ocomm.Delete_attr(innercomm_keyval) del icomm - if remove: - # Only do this if not called from free_comms. + try: dupped_comms.remove(comm) + except ValueError: + debug(f"{comm.name} is not in list of known comms, probably already freed") + debug(f"Known comms are {[d.name for d in dupped_comms if d != MPI.COMM_NULL]}") compilation_comm = get_compilation_comm(comm) - if compilation_comm is not None: - compilation_comm.Free() + if compilation_comm == MPI.COMM_NULL: + comm.Delete_attr(compilationcomm_keyval) + elif compilation_comm is not None: + free_comm(compilation_comm) + comm.Delete_attr(compilationcomm_keyval) comm.Free() + else: + warning('Attempt to free MPI_COMM_NULL') @atexit.register def free_comms(): """Free all outstanding communicators.""" + # Collect garbage as it may hold on to communicator references + global PYOP2_FINALIZED + PYOP2_FINALIZED = True + debug("PyOP2 Finalizing") + debug("Calling gc.collect()") + import gc + gc.collect() + pyop2_comm_status() + print(dupped_comms) + debug(f"Freeing comms in list (length {len(dupped_comms)})") while dupped_comms: - c = dupped_comms.pop() - refcount = c.Get_attr(refcount_keyval) - for _ in range(refcount[0]): - free_comm(c, remove=False) + c = dupped_comms[-1] + if is_pyop2_comm(c): + refcount = c.Get_attr(refcount_keyval) + debug(f"Freeing {c.name}, which has refcount {refcount[0]}") + else: + debug("Freeing non PyOP2 comm in `free_comms()`") + free_comm(c) for kv in [refcount_keyval, innercomm_keyval, outercomm_keyval, @@ -232,19 +466,8 @@ def free_comms(): def hash_comm(comm): """Return a hashable identifier for a communicator.""" - # dup_comm returns a persistent internal communicator so we can - # use its id() as the hash since this is stable between invocations. - return id(dup_comm(comm)) - - -def collective(fn): - extra = trim(""" - This function is logically collective over MPI ranks, it is an - error to call it on fewer than all the ranks in MPI communicator. - """) - fn.__doc__ = "%s\n\n%s" % (trim(fn.__doc__), extra) if fn.__doc__ else extra - return fn - + assert is_pyop2_comm(comm) + return id(comm) # Install an exception hook to MPI Abort if an exception isn't caught # see: https://groups.google.com/d/msg/mpi4py/me2TFzHmmsQ/sSF99LE0t9QJ diff --git a/pyop2/op2.py b/pyop2/op2.py index 1fe7f9d8a..1a4c805d4 100644 --- a/pyop2/op2.py +++ b/pyop2/op2.py @@ -69,6 +69,9 @@ _initialised = False +# set the log level +print('PyOP2 log level:', configuration['log_level']) +set_log_level(configuration['log_level']) def initialised(): """Check whether PyOP2 has been yet initialised but not yet finalised.""" @@ -101,7 +104,7 @@ def init(**kwargs): configuration.reconfigure(**kwargs) set_log_level(configuration['log_level']) - + import pytest; pytest.set_trace() _initialised = True diff --git a/pyop2/parloop.py b/pyop2/parloop.py index 0ba340ee4..c35f21ec3 100644 --- a/pyop2/parloop.py +++ b/pyop2/parloop.py @@ -18,6 +18,7 @@ from pyop2.types import (Access, Global, AbstractDat, Dat, DatView, MixedDat, Mat, Set, MixedSet, ExtrudedSet, Subset, Map, ComposedMap, MixedMap) from pyop2.utils import cached_property +from pyop2.logger import debug class ParloopArg(abc.ABC): @@ -150,11 +151,14 @@ def __init__(self, global_knl, iterset, arguments): self.global_kernel = global_knl self.iterset = iterset + self.comm = mpi.internal_comm(iterset.comm) self.arguments, self.reduced_globals = self.prepare_reduced_globals(arguments, global_knl) + debug(f"INIT {self.__class__} and assign {self.comm.name}") - @property - def comm(self): - return self.iterset.comm + def __del__(self): + if hasattr(self, "comm"): + debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") + mpi.decref(self.comm) @property def local_kernel(self): diff --git a/pyop2/sparsity.pyx b/pyop2/sparsity.pyx index 0f327e3db..282ec042d 100644 --- a/pyop2/sparsity.pyx +++ b/pyop2/sparsity.pyx @@ -124,7 +124,7 @@ def build_sparsity(sparsity): nest = sparsity.nested if mixed and sparsity.nested: raise ValueError("Can't build sparsity on mixed nest, build the sparsity on the blocks") - preallocator = PETSc.Mat().create(comm=sparsity.comm) + preallocator = PETSc.Mat().create(comm=sparsity.comm.ob_mpi) preallocator.setType(PETSc.Mat.Type.PREALLOCATOR) if mixed: # Sparsity is the dof sparsity. diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index 03df1937b..11580f3cd 100644 --- a/pyop2/types/dat.py +++ b/pyop2/types/dat.py @@ -19,6 +19,7 @@ from pyop2.types.dataset import DataSet, GlobalDataSet, MixedDataSet from pyop2.types.data_carrier import DataCarrier, EmptyDataMixin, VecAccessMixin from pyop2.types.set import ExtrudedSet, GlobalSet, Set +from pyop2.logger import debug class AbstractDat(DataCarrier, EmptyDataMixin, abc.ABC): @@ -81,9 +82,15 @@ def __init__(self, dataset, data=None, dtype=None, name=None): EmptyDataMixin.__init__(self, data, dtype, self._shape) self._dataset = dataset - self.comm = dataset.comm + self.comm = mpi.internal_comm(dataset.comm) self.halo_valid = True self._name = name or "dat_#x%x" % id(self) + debug(f"INIT {self.__class__} and assign {self.comm.name}") + + def __del__(self): + if hasattr(self, "comm"): + debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") + mpi.decref(self.comm) self._halo_frozen = False self._frozen_access_mode = None @@ -768,7 +775,8 @@ def what(x): if not all(d.dtype == self._dats[0].dtype for d in self._dats): raise ex.DataValueError('MixedDat with different dtypes is not supported') # TODO: Think about different communicators on dats (c.f. MixedSet) - self.comm = self._dats[0].comm + self.comm = mpi.internal_comm(self._dats[0].comm) + debug(f"INIT {self.__class__} and assign {self.comm.name}") @property def dat_version(self): diff --git a/pyop2/types/data_carrier.py b/pyop2/types/data_carrier.py index 73d3974c2..fcf5f95f1 100644 --- a/pyop2/types/data_carrier.py +++ b/pyop2/types/data_carrier.py @@ -64,6 +64,7 @@ def __init__(self, data, dtype, shape): self._dtype = self._data.dtype @utils.cached_property + # ~ @property def _data(self): """Return the user-provided data buffer, or a zeroed buffer of the correct size if none was provided.""" diff --git a/pyop2/types/dataset.py b/pyop2/types/dataset.py index 635b130e3..0437f7e63 100644 --- a/pyop2/types/dataset.py +++ b/pyop2/types/dataset.py @@ -11,6 +11,7 @@ utils ) from pyop2.types.set import ExtrudedSet, GlobalSet, MixedSet, Set, Subset +from pyop2.logger import debug class DataSet(caching.ObjectCached): @@ -29,11 +30,19 @@ def __init__(self, iter_set, dim=1, name=None): return if isinstance(iter_set, Subset): raise NotImplementedError("Deriving a DataSet from a Subset is unsupported") + self.comm = mpi.internal_comm(iter_set.comm) self._set = iter_set self._dim = utils.as_tuple(dim, numbers.Integral) self._cdim = np.prod(self._dim).item() self._name = name or "dset_#x%x" % id(self) self._initialized = True + debug(f"INIT {self.__class__} and assign {self.comm.name}") + + def __del__(self): + # ~ if hasattr(self, "comm"): + if "comm" in self.__dict__: + debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") + mpi.decref(self.comm) @classmethod def _process_args(cls, *args, **kwargs): @@ -59,7 +68,6 @@ def __setstate__(self, d): def __getattr__(self, name): """Returns a Set specific attribute.""" value = getattr(self.set, name) - setattr(self, name, value) return value def __getitem__(self, idx): @@ -202,10 +210,14 @@ class GlobalDataSet(DataSet): def __init__(self, global_): """ :param global_: The :class:`Global` on which this object is based.""" - + if self._initialized: + return self._global = global_ + self.comm = mpi.internal_comm(global_.comm) self._globalset = GlobalSet(comm=self.comm) self._name = "gdset_#x%x" % id(self) + self._initialized = True + debug(f"INIT {self.__class__} and assign {self.comm.name}") @classmethod def _cache_key(cls, *args): @@ -227,11 +239,6 @@ def name(self): """Returns the name of the data set.""" return self._global._name - @utils.cached_property - def comm(self): - """Return the communicator on which the set is defined.""" - return self._global.comm - @utils.cached_property def set(self): """Returns the parent set of the data set.""" @@ -371,7 +378,14 @@ def __init__(self, arg, dims=None): if self._initialized: return self._dsets = arg + try: + # Try/except may not be necessary, someone needs to think about this... + comm = self._process_args(arg, dims)[0][0].comm + except AttributeError: + comm = None + self.comm = mpi.internal_comm(comm) self._initialized = True + debug(f"INIT {self.__class__} and assign {self.comm.name}") @classmethod def _process_args(cls, arg, dims=None): diff --git a/pyop2/types/glob.py b/pyop2/types/glob.py index dd5a609a8..7e31efe63 100644 --- a/pyop2/types/glob.py +++ b/pyop2/types/glob.py @@ -13,6 +13,7 @@ from pyop2.types.access import Access from pyop2.types.dataset import GlobalDataSet from pyop2.types.data_carrier import DataCarrier, EmptyDataMixin, VecAccessMixin +from pyop2.logger import debug class Global(DataCarrier, EmptyDataMixin, VecAccessMixin): @@ -39,21 +40,30 @@ class Global(DataCarrier, EmptyDataMixin, VecAccessMixin): @utils.validate_type(('name', str, ex.NameTypeError)) def __init__(self, dim, data=None, dtype=None, name=None, comm=None): + debug(f"calling Global.__init__") if isinstance(dim, Global): # If g is a Global, Global(g) performs a deep copy. This is for compatibility with Dat. self.__init__(dim._dim, None, dtype=dim.dtype, name="copy_of_%s" % dim.name, comm=dim.comm) dim.copy(self) - return - self._dim = utils.as_tuple(dim, int) - self._cdim = np.prod(self._dim).item() - EmptyDataMixin.__init__(self, data, dtype, self._dim) - self._buf = np.empty(self.shape, dtype=self.dtype) - self._name = name or "global_#x%x" % id(self) - self.comm = comm - # Object versioning setup - petsc_counter = (self.comm and self.dtype == PETSc.ScalarType) - VecAccessMixin.__init__(self, petsc_counter=petsc_counter) + else: + self._dim = utils.as_tuple(dim, int) + self._cdim = np.prod(self._dim).item() + EmptyDataMixin.__init__(self, data, dtype, self._dim) + self._buf = np.empty(self.shape, dtype=self.dtype) + self._name = name or "global_#x%x" % id(self) + # ~ import pdb; pdb.set_trace() + self.comm = mpi.internal_comm(comm) + # Object versioning setup + # ~ petsc_counter = (self.comm and self.dtype == PETSc.ScalarType) + petsc_counter = (comm and self.dtype == PETSc.ScalarType) + VecAccessMixin.__init__(self, petsc_counter=petsc_counter) + debug(f"INIT {self.__class__} and assign {self.comm.name}") + + def __del__(self): + if hasattr(self, "comm"): + debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") + mpi.decref(self.comm) @utils.cached_property def _kernel_args_(self): @@ -96,7 +106,8 @@ def __repr__(self): return "Global(%r, %r, %r, %r)" % (self._dim, self._data, self._data.dtype, self._name) - @utils.cached_property + # ~ @utils.cached_property + @property def dataset(self): return GlobalDataSet(self) @@ -281,7 +292,8 @@ def inner(self, other): assert isinstance(other, Global) return np.dot(self.data_ro, np.conj(other.data_ro)) - @utils.cached_property + # ~ @utils.cached_property + @property def _vec(self): assert self.dtype == PETSc.ScalarType, \ "Can't create Vec with type %s, must be %s" % (self.dtype, PETSc.ScalarType) diff --git a/pyop2/types/map.py b/pyop2/types/map.py index 7eedbdc50..516a9bd53 100644 --- a/pyop2/types/map.py +++ b/pyop2/types/map.py @@ -10,7 +10,9 @@ exceptions as ex, utils ) +from pyop2 import mpi from pyop2.types.set import GlobalSet, MixedSet, Set +from pyop2.logger import debug class Map: @@ -35,7 +37,7 @@ class Map: def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, offset_quotient=None): self._iterset = iterset self._toset = toset - self.comm = toset.comm + self.comm = mpi.internal_comm(toset.comm) self._arity = arity self._values = utils.verify_reshape(values, dtypes.IntType, (iterset.total_size, arity), allow_none=True) @@ -51,6 +53,12 @@ def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, o self._offset_quotient = utils.verify_reshape(offset_quotient, dtypes.IntType, (arity, )) # A cache for objects built on top of this map self._cache = {} + debug(f"INIT {self.__class__} and assign {self.comm.name}") + + def __del__(self): + if hasattr(self, "comm"): + debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") + mpi.decref(self.comm) @utils.cached_property def _kernel_args_(self): @@ -195,6 +203,7 @@ def __init__(self, map_, permutation): if isinstance(map_, ComposedMap): raise NotImplementedError("PermutedMap of ComposedMap not implemented: simply permute before composing") self.map_ = map_ + self.comm = mpi.internal_comm(map_.comm) self.permutation = np.asarray(permutation, dtype=Map.dtype) assert (np.unique(permutation) == np.arange(map_.arity, dtype=Map.dtype)).all() @@ -309,8 +318,9 @@ def __init__(self, maps): raise ex.MapTypeError("All maps needs to share a communicator") if len(comms) == 0: raise ex.MapTypeError("Don't know how to make communicator") - self.comm = comms[0] + self.comm = mpi.internal_comm(comms[0]) self._initialized = True + debug(f"INIT {self.__class__} and assign {self.comm.name}") @classmethod def _process_args(cls, *args, **kwargs): diff --git a/pyop2/types/mat.py b/pyop2/types/mat.py index 87e79c9e6..48bfd1e9d 100644 --- a/pyop2/types/mat.py +++ b/pyop2/types/mat.py @@ -20,6 +20,7 @@ from pyop2.types.dataset import DataSet, GlobalDataSet, MixedDataSet from pyop2.types.map import Map, ComposedMap from pyop2.types.set import MixedSet, Set, Subset +from pyop2.logger import debug class Sparsity(caching.ObjectCached): @@ -56,6 +57,7 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, if self._initialized: return + debug(f"INIT {self.__class__} BEGIN") self._block_sparse = block_sparse # Split into a list of row maps and a list of column maps maps, iteration_regions = zip(*maps) @@ -68,11 +70,11 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, self._o_nnz = None self._nrows = None if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].toset.size self._ncols = None if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].toset.size - self.lcomm = dsets[0].comm if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].comm - self.rcomm = dsets[1].comm if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].comm + self.lcomm = mpi.internal_comm(dsets[0].comm if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].comm) + self.rcomm = mpi.internal_comm(dsets[1].comm if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].comm) else: - self.lcomm = self._rmaps[0].comm - self.rcomm = self._cmaps[0].comm + self.lcomm = mpi.internal_comm(self._rmaps[0].comm) + self.rcomm = mpi.internal_comm(self._cmaps[0].comm) rset, cset = self.dsets # All rmaps and cmaps have the same data set - just use the first. @@ -93,10 +95,8 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, if self.lcomm != self.rcomm: raise ValueError("Haven't thought hard enough about different left and right communicators") - self.comm = self.lcomm - + self.comm = mpi.internal_comm(self.lcomm) self._name = name or "sparsity_#x%x" % id(self) - self.iteration_regions = iteration_regions # If the Sparsity is defined on MixedDataSets, we need to build each # block separately @@ -130,6 +130,16 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, self._o_nnz = onnz self._blocks = [[self]] self._initialized = True + debug(f"INIT {self.__class__} and assign {self.comm.name}") + + def __del__(self): + if hasattr(self, "comm"): + debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") + mpi.decref(self.comm) + if hasattr(self, "lcomm"): + mpi.decref(self.lcomm) + if hasattr(self, "rcomm"): + mpi.decref(self.rcomm) _cache = {} @@ -373,10 +383,18 @@ def __init__(self, parent, i, j): self._dims = tuple([tuple([parent.dims[i][j]])]) self._blocks = [[self]] self.iteration_regions = parent.iteration_regions - self.lcomm = self.dsets[0].comm - self.rcomm = self.dsets[1].comm + self.lcomm = mpi.internal_comm(self.dsets[0].comm) + self.rcomm = mpi.internal_comm(self.dsets[1].comm) # TODO: think about lcomm != rcomm - self.comm = self.lcomm + self.comm = mpi.internal_comm(self.lcomm) + + def __del__(self): + if hasattr(self, "comm"): + mpi.decref(self.comm) + if hasattr(self, "lcomm"): + mpi.decref(self.lcomm) + if hasattr(self, "rcomm"): + mpi.decref(self.rcomm) @classmethod def _process_args(cls, *args, **kwargs): @@ -434,13 +452,23 @@ class AbstractMat(DataCarrier, abc.ABC): ('name', str, ex.NameTypeError)) def __init__(self, sparsity, dtype=None, name=None): self._sparsity = sparsity - self.lcomm = sparsity.lcomm - self.rcomm = sparsity.rcomm - self.comm = sparsity.comm + self.lcomm = mpi.internal_comm(sparsity.lcomm) + self.rcomm = mpi.internal_comm(sparsity.rcomm) + self.comm = mpi.internal_comm(sparsity.comm) dtype = dtype or dtypes.ScalarType self._datatype = np.dtype(dtype) self._name = name or "mat_#x%x" % id(self) self.assembly_state = Mat.ASSEMBLED + debug(f"INIT {self.__class__} and assign {self.comm.name}") + + def __del__(self): + if hasattr(self, "comm"): + debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") + mpi.decref(self.comm) + if hasattr(self, "lcomm"): + mpi.decref(self.lcomm) + if hasattr(self, "rcomm"): + mpi.decref(self.rcomm) @utils.validate_in(('access', _modes, ex.ModeValueError)) def __call__(self, access, path, lgmaps=None, unroll_map=False): @@ -939,8 +967,9 @@ def __init__(self, parent, i, j): colis = cset.local_ises[j] self.handle = parent.handle.getLocalSubMatrix(isrow=rowis, iscol=colis) - self.comm = parent.comm + self.comm = mpi.internal_comm(parent.comm) self.local_to_global_maps = self.handle.getLGMap() + debug(f"INIT {self.__class__} and assign {self.comm.name}") @property def dat_version(self): @@ -1094,10 +1123,8 @@ def mult(self, mat, x, y): a[0] = x.array_r else: x.array_r - - comm = mpi.dup_comm(x.comm) - comm.bcast(a) - mpi.free_comm(comm) + with mpi.PyOP2Comm(x.comm) as comm: + comm.bcast(a) return y.scale(a) else: return v.pointwiseMult(x, y) @@ -1113,9 +1140,8 @@ def multTranspose(self, mat, x, y): a[0] = x.array_r else: x.array_r - comm = mpi.dup_comm(x.comm) - comm.bcast(a) - mpi.free_comm(comm) + with mpi.PyOP2Comm(x.comm) as comm: + comm.bcast(a) y.scale(a) else: v.pointwiseMult(x, y) @@ -1139,9 +1165,8 @@ def multTransposeAdd(self, mat, x, y, z): a[0] = x.array_r else: x.array_r - comm = mpi.dup_comm(x.comm) - comm.bcast(a) - mpi.free_comm(comm) + with mpi.PyOP2Comm(x.comm) as comm: + comm.bcast(a) if y == z: # Last two arguments are aliased. tmp = y.duplicate() diff --git a/pyop2/types/set.py b/pyop2/types/set.py index fed118b1c..bc605e02d 100644 --- a/pyop2/types/set.py +++ b/pyop2/types/set.py @@ -11,6 +11,7 @@ mpi, utils ) +from pyop2.logger import debug class Set: @@ -65,7 +66,7 @@ def _wrapper_cache_key_(self): @utils.validate_type(('size', (numbers.Integral, tuple, list, np.ndarray), ex.SizeTypeError), ('name', str, ex.NameTypeError)) def __init__(self, size, name=None, halo=None, comm=None): - self.comm = mpi.dup_comm(comm) + self.comm = mpi.internal_comm(comm) if isinstance(size, numbers.Integral): size = [size] * 3 size = utils.as_tuple(size, numbers.Integral, 3) @@ -77,6 +78,13 @@ def __init__(self, size, name=None, halo=None, comm=None): self._partition_size = 1024 # A cache of objects built on top of this set self._cache = {} + debug(f"INIT {self.__class__} and assign {self.comm.name}") + + def __del__(self): + # ~ if hasattr(self, "comm"): + if "comm" in self.__dict__: + debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") + mpi.decref(self.comm) @utils.cached_property def core_size(self): @@ -219,8 +227,11 @@ class GlobalSet(Set): _argtypes_ = () def __init__(self, comm=None): - self.comm = mpi.dup_comm(comm) + debug(f"calling GlobalSet.__init__") + # ~ import pdb; pdb.set_trace() + self.comm = mpi.internal_comm(comm) self._cache = {} + debug(f"INIT {self.__class__} and assign {self.comm.name}") @utils.cached_property def core_size(self): @@ -304,6 +315,7 @@ class ExtrudedSet(Set): @utils.validate_type(('parent', Set, TypeError)) def __init__(self, parent, layers, extruded_periodic=False): self._parent = parent + self.comm = mpi.internal_comm(parent.comm) try: layers = utils.verify_reshape(layers, dtypes.IntType, (parent.total_size, 2)) self.constant_layers = False @@ -325,6 +337,7 @@ def __init__(self, parent, layers, extruded_periodic=False): self._layers = layers self._extruded = True self._extruded_periodic = extruded_periodic + debug(f"INIT {self.__class__} and assign {self.comm.name}") @utils.cached_property def _kernel_args_(self): @@ -341,7 +354,6 @@ def _wrapper_cache_key_(self): def __getattr__(self, name): """Returns a :class:`Set` specific attribute.""" value = getattr(self._parent, name) - setattr(self, name, value) return value def __contains__(self, set): @@ -385,6 +397,8 @@ class Subset(ExtrudedSet): @utils.validate_type(('superset', Set, TypeError), ('indices', (list, tuple, np.ndarray), TypeError)) def __init__(self, superset, indices): + self.comm = mpi.internal_comm(superset.comm) + # sort and remove duplicates indices = np.unique(indices) if isinstance(superset, Subset): @@ -407,6 +421,7 @@ def __init__(self, superset, indices): len(self._indices)) self._extruded = superset._extruded self._extruded_periodic = superset._extruded_periodic + debug(f"INIT {self.__class__} and assign {self.comm.name}") @utils.cached_property def _kernel_args_(self): @@ -420,7 +435,6 @@ def _argtypes_(self): def __getattr__(self, name): """Returns a :class:`Set` specific attribute.""" value = getattr(self._superset, name) - setattr(self, name, value) return value def __pow__(self, e): @@ -528,8 +542,15 @@ def __init__(self, sets): assert all(s is None or isinstance(s, GlobalSet) or ((s.layers == self._sets[0].layers).all() if s.layers is not None else True) for s in sets), \ "All components of a MixedSet must have the same number of layers." # TODO: do all sets need the same communicator? - self.comm = functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.comm, sets)) + self.comm = mpi.internal_comm(functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.comm, sets))) self._initialized = True + debug(f"INIT {self.__class__} and assign {self.comm.name}") + + def __del__(self): + if self._initialized and hasattr(self, "comm"): + # ~ if "comm" in self.__dict__.keys(): + debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") + mpi.decref(self.comm) @utils.cached_property def _kernel_args_(self): diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index ff103bfd2..f175bc76f 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -540,10 +540,11 @@ def myfunc(arg): """Example function to cache the outputs of.""" return {arg} - @staticmethod - def collective_key(*args): + def collective_key(self, *args): """Return a cache key suitable for use when collective over a communicator.""" - return mpi.COMM_SELF, cachetools.keys.hashkey(*args) + # Explicitly `mpi.decref(self.comm)` in any test that uses this comm + self.comm = mpi.internal_comm(mpi.COMM_SELF) + return self.comm, cachetools.keys.hashkey(*args) @pytest.fixture def cache(cls): @@ -580,6 +581,7 @@ def test_decorator_collective_has_different_in_memory_key(self, cache, cachedir) assert obj1 == obj2 and obj1 is not obj2 assert len(cache) == 2 assert len(os.listdir(cachedir.name)) == 1 + mpi.decref(self.comm) def test_decorator_disk_cache_reuses_results(self, cache, cachedir): decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) From fb274d22e7ed8284eac83e1dec9d4037a8c140dd Mon Sep 17 00:00:00 2001 From: JDBetteridge Date: Tue, 11 Oct 2022 13:51:43 +0100 Subject: [PATCH 03/17] This test was just wrong --- test/unit/test_matrices.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/unit/test_matrices.py b/test/unit/test_matrices.py index a84ea1aac..f66bac8f3 100644 --- a/test/unit/test_matrices.py +++ b/test/unit/test_matrices.py @@ -795,7 +795,6 @@ def test_mat_nbytes(self, mat): """Check that the matrix uses the amount of memory we expect.""" assert mat.nbytes == 14 * 8 - class TestMatrixStateChanges: """ @@ -822,7 +821,7 @@ def mat(self, request, msparsity, non_nest_mixed_sparsity): def test_mat_starts_assembled(self, mat): assert mat.assembly_state is op2.Mat.ASSEMBLED for m in mat: - assert mat.assembly_state is op2.Mat.ASSEMBLED + assert m.assembly_state is op2.Mat.ASSEMBLED def test_after_set_local_state_is_insert(self, mat): mat[0, 0].set_local_diagonal_entries([0]) From 35de348f7e1333cae253491549b6c644cd080b27 Mon Sep 17 00:00:00 2001 From: JDBetteridge Date: Thu, 13 Oct 2022 17:15:01 +0100 Subject: [PATCH 04/17] Tests pass with no comms referenced at end --- pyop2/types/mat.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pyop2/types/mat.py b/pyop2/types/mat.py index 48bfd1e9d..b5a09e192 100644 --- a/pyop2/types/mat.py +++ b/pyop2/types/mat.py @@ -373,6 +373,11 @@ class SparsityBlock(Sparsity): This class only implements the properties necessary to infer its shape. It does not provide arrays of non zero fill.""" def __init__(self, parent, i, j): + # Protect against re-initialization when retrieved from cache + if self._initialized: + return + + debug(f"INIT {self.__class__} BEGIN") self._dsets = (parent.dsets[0][i], parent.dsets[1][j]) self._rmaps = tuple(m.split[i] for m in parent.rmaps) self._cmaps = tuple(m.split[j] for m in parent.cmaps) @@ -387,14 +392,8 @@ def __init__(self, parent, i, j): self.rcomm = mpi.internal_comm(self.dsets[1].comm) # TODO: think about lcomm != rcomm self.comm = mpi.internal_comm(self.lcomm) - - def __del__(self): - if hasattr(self, "comm"): - mpi.decref(self.comm) - if hasattr(self, "lcomm"): - mpi.decref(self.lcomm) - if hasattr(self, "rcomm"): - mpi.decref(self.rcomm) + self._initialized = True + debug(f"INIT {self.__class__} and assign {self.comm.name}") @classmethod def _process_args(cls, *args, **kwargs): @@ -958,6 +957,7 @@ class MatBlock(AbstractMat): :arg j: The block column. """ def __init__(self, parent, i, j): + debug(f"INIT {self.__class__} BEGIN") self._parent = parent self._i = i self._j = j From 51b21f4298df91fd170752ed1a67519b7e9f7a2d Mon Sep 17 00:00:00 2001 From: JDBetteridge Date: Tue, 18 Oct 2022 16:51:39 +0100 Subject: [PATCH 05/17] Lint code --- pyop2/caching.py | 5 +---- pyop2/compilation.py | 4 +--- pyop2/mpi.py | 34 +++++++++++++--------------------- pyop2/op2.py | 3 +-- pyop2/types/dataset.py | 2 +- pyop2/types/glob.py | 17 +++++++++-------- pyop2/types/set.py | 3 +-- test/unit/test_matrices.py | 1 + 8 files changed, 28 insertions(+), 41 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 28ee74a9a..24a3f5513 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -41,7 +41,7 @@ import cachetools from pyop2.configuration import configuration -from pyop2.mpi import hash_comm, is_pyop2_comm +from pyop2.mpi import hash_comm from pyop2.utils import cached_property @@ -274,9 +274,6 @@ def wrapper(*args, **kwargs): if collective: comm, disk_key = key(*args, **kwargs) disk_key = _as_hexdigest(disk_key) - # ~ k = id(comm), disk_key - # ~ if not is_pyop2_comm(comm): - # ~ import pytest; pytest.set_trace() k = hash_comm(comm), disk_key else: k = _as_hexdigest(key(*args, **kwargs)) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index 831c775e8..0edb853cd 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -49,7 +49,7 @@ from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError from petsc4py import PETSc -from pyop2.logger import debug + def _check_hashes(x, y, datatype): """MPI reduction op to check if code hashes differ across ranks.""" @@ -408,8 +408,6 @@ def get_so(self, jitmodule, extension): # Atomically ensure soname exists os.rename(tmpname, soname) # Wait for compilation to complete - if self.comm == mpi.MPI.COMM_NULL: - import pytest; pytest.set_trace() self.comm.barrier() # Load resulting library return ctypes.CDLL(soname) diff --git a/pyop2/mpi.py b/pyop2/mpi.py index cb48efc60..265d4080b 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -37,13 +37,14 @@ from petsc4py import PETSc from mpi4py import MPI # noqa import atexit -import inspect # remove later +import os from pyop2.configuration import configuration -from pyop2.logger import warning, debug, progress, INFO +from pyop2.exceptions import CompilationError +from pyop2.logger import warning, debug from pyop2.utils import trim -__all__ = ("COMM_WORLD", "COMM_SELF", "MPI", "internal_comm", "is_pyop2_comm", "incref", "decref") +__all__ = ("COMM_WORLD", "COMM_SELF", "MPI", "internal_comm", "is_pyop2_comm", "incref", "decref", "PyOP2Comm") # These are user-level communicators, we never send any messages on # them inside PyOP2. @@ -163,16 +164,16 @@ def is_pyop2_comm(comm): if isinstance(comm, PETSc.Comm): ispyop2comm = False elif comm == MPI.COMM_NULL: - if PYOP2_FINALIZED == False: + if PYOP2_FINALIZED is False: # ~ import pytest; pytest.set_trace() # ~ raise ValueError("COMM_NULL") ispyop2comm = True else: ispyop2comm = True - elif isinstance(comm, MPI.Comm): + elif isinstance(comm, (MPI.Comm, FriendlyCommNull)): ispyop2comm = bool(comm.Get_attr(refcount_keyval)) else: - raise ValueError("Argument passed to is_pyop2_comm() is not a recognised comm type") + raise ValueError(f"Argument passed to is_pyop2_comm() is a {type(comm)}, which is not a recognised comm type") return ispyop2comm @@ -228,8 +229,7 @@ def internal_comm(comm): pyop2_comm = comm elif isinstance(comm, PETSc.Comm): # Convert PETSc.Comm to mpi4py.MPI.Comm - comm = dup_comm(comm.tompi4py()) - pyop2_comm.Set_name(f"PYOP2_{comm.name or id(comm)}") + pyop2_comm = dup_comm(comm.tompi4py()) elif comm == MPI.COMM_NULL: # Ensure comm is not the NULL communicator raise ValueError("MPI_COMM_NULL passed to internal_comm()") @@ -259,10 +259,10 @@ def decref(comm): # ~ if not PYOP2_FINALIZED: refcount = comm.Get_attr(refcount_keyval) refcount[0] -= 1 - debug(f'{comm.name} DECREF to {refcount[0]}') - if refcount[0] == 0: + # ~ debug(f'{comm.name} DECREF to {refcount[0]}') + if refcount[0] == 0 and not isinstance(comm, FriendlyCommNull): dupped_comms.remove(comm) - debug(f'Freeing {comm.name}') + # ~ debug(f'Freeing {comm.name}') free_comm(comm) @@ -282,6 +282,7 @@ def dup_comm(comm_in): comm_in.Set_attr(innercomm_keyval, comm_out) comm_out.Set_attr(outercomm_keyval, comm_in) # Name + # replace id() with .py2f() ??? comm_out.Set_name(f"{comm_in.name or id(comm_in)}_DUP") # Refcount comm_out.Set_attr(refcount_keyval, [0]) @@ -402,16 +403,6 @@ def free_comm(comm): """ if comm != MPI.COMM_NULL: assert is_pyop2_comm(comm) - # ~ if is_pyop2_comm(comm): - # ~ # Not a PyOP2 communicator, check for an embedded comm. - # ~ comm = comm.Get_attr(innercomm_keyval) - # ~ if comm is None: - # ~ raise ValueError("Trying to destroy communicator not known to PyOP2") - # ~ if not is_pyop2_comm(comm): - # ~ raise ValueError("Inner comm is not a PyOP2 comm") - - # ~ decref(comm) - ocomm = comm.Get_attr(outercomm_keyval) if ocomm is not None: icomm = ocomm.Get_attr(innercomm_keyval) @@ -469,6 +460,7 @@ def hash_comm(comm): assert is_pyop2_comm(comm) return id(comm) + # Install an exception hook to MPI Abort if an exception isn't caught # see: https://groups.google.com/d/msg/mpi4py/me2TFzHmmsQ/sSF99LE0t9QJ if COMM_WORLD.size > 1: diff --git a/pyop2/op2.py b/pyop2/op2.py index 1a4c805d4..726168e79 100644 --- a/pyop2/op2.py +++ b/pyop2/op2.py @@ -70,9 +70,9 @@ _initialised = False # set the log level -print('PyOP2 log level:', configuration['log_level']) set_log_level(configuration['log_level']) + def initialised(): """Check whether PyOP2 has been yet initialised but not yet finalised.""" return _initialised @@ -104,7 +104,6 @@ def init(**kwargs): configuration.reconfigure(**kwargs) set_log_level(configuration['log_level']) - import pytest; pytest.set_trace() _initialised = True diff --git a/pyop2/types/dataset.py b/pyop2/types/dataset.py index 0437f7e63..8191db4e9 100644 --- a/pyop2/types/dataset.py +++ b/pyop2/types/dataset.py @@ -41,7 +41,7 @@ def __init__(self, iter_set, dim=1, name=None): def __del__(self): # ~ if hasattr(self, "comm"): if "comm" in self.__dict__: - debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") + # ~ debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") mpi.decref(self.comm) @classmethod diff --git a/pyop2/types/glob.py b/pyop2/types/glob.py index 7e31efe63..464eb2d6f 100644 --- a/pyop2/types/glob.py +++ b/pyop2/types/glob.py @@ -40,7 +40,7 @@ class Global(DataCarrier, EmptyDataMixin, VecAccessMixin): @utils.validate_type(('name', str, ex.NameTypeError)) def __init__(self, dim, data=None, dtype=None, name=None, comm=None): - debug(f"calling Global.__init__") + debug("calling Global.__init__") if isinstance(dim, Global): # If g is a Global, Global(g) performs a deep copy. This is for compatibility with Dat. self.__init__(dim._dim, None, dtype=dim.dtype, @@ -52,16 +52,19 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None): EmptyDataMixin.__init__(self, data, dtype, self._dim) self._buf = np.empty(self.shape, dtype=self.dtype) self._name = name or "global_#x%x" % id(self) - # ~ import pdb; pdb.set_trace() self.comm = mpi.internal_comm(comm) # Object versioning setup # ~ petsc_counter = (self.comm and self.dtype == PETSc.ScalarType) petsc_counter = (comm and self.dtype == PETSc.ScalarType) VecAccessMixin.__init__(self, petsc_counter=petsc_counter) - debug(f"INIT {self.__class__} and assign {self.comm.name}") + try: + name = self.comm.name + except AttributeError: + name = "None" + debug(f"INIT {self.__class__} and assign {name}") def __del__(self): - if hasattr(self, "comm"): + if hasattr(self, "comm") and self.comm is not None: debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") mpi.decref(self.comm) @@ -106,8 +109,7 @@ def __repr__(self): return "Global(%r, %r, %r, %r)" % (self._dim, self._data, self._data.dtype, self._name) - # ~ @utils.cached_property - @property + @utils.cached_property def dataset(self): return GlobalDataSet(self) @@ -292,8 +294,7 @@ def inner(self, other): assert isinstance(other, Global) return np.dot(self.data_ro, np.conj(other.data_ro)) - # ~ @utils.cached_property - @property + @utils.cached_property def _vec(self): assert self.dtype == PETSc.ScalarType, \ "Can't create Vec with type %s, must be %s" % (self.dtype, PETSc.ScalarType) diff --git a/pyop2/types/set.py b/pyop2/types/set.py index bc605e02d..25d3b17e9 100644 --- a/pyop2/types/set.py +++ b/pyop2/types/set.py @@ -227,7 +227,7 @@ class GlobalSet(Set): _argtypes_ = () def __init__(self, comm=None): - debug(f"calling GlobalSet.__init__") + debug("calling GlobalSet.__init__") # ~ import pdb; pdb.set_trace() self.comm = mpi.internal_comm(comm) self._cache = {} @@ -548,7 +548,6 @@ def __init__(self, sets): def __del__(self): if self._initialized and hasattr(self, "comm"): - # ~ if "comm" in self.__dict__.keys(): debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") mpi.decref(self.comm) diff --git a/test/unit/test_matrices.py b/test/unit/test_matrices.py index f66bac8f3..34b467e21 100644 --- a/test/unit/test_matrices.py +++ b/test/unit/test_matrices.py @@ -795,6 +795,7 @@ def test_mat_nbytes(self, mat): """Check that the matrix uses the amount of memory we expect.""" assert mat.nbytes == 14 * 8 + class TestMatrixStateChanges: """ From e3454bf65d360beb0cb0e834946d9122589f1da8 Mon Sep 17 00:00:00 2001 From: JDBetteridge Date: Tue, 18 Oct 2022 17:06:11 +0100 Subject: [PATCH 06/17] Remove debugging statements --- pyop2/compilation.py | 4 ---- pyop2/mpi.py | 5 +---- pyop2/parloop.py | 3 --- pyop2/sparsity.pyx | 2 +- pyop2/types/dat.py | 4 ---- pyop2/types/data_carrier.py | 1 - pyop2/types/dataset.py | 7 +------ pyop2/types/glob.py | 7 +------ pyop2/types/map.py | 4 ---- pyop2/types/mat.py | 10 ---------- pyop2/types/set.py | 12 +----------- 11 files changed, 5 insertions(+), 54 deletions(-) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index 0edb853cd..2dad49e51 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -189,15 +189,11 @@ def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, co self.pcomm = mpi.internal_comm(comm) self.comm = mpi.compilation_comm(self.pcomm) self.sniff_compiler_version() - debug(f"INIT {self.__class__} and assign {self.comm.name}") - debug(f"INIT {self.__class__} and assign {self.pcomm.name}") def __del__(self): if hasattr(self, "comm"): - debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") mpi.decref(self.comm) if hasattr(self, "pcomm"): - debug(f"DELETE {self.__class__} and removing reference to {self.pcomm.name}") mpi.decref(self.pcomm) def __repr__(self): diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 265d4080b..7355c05cf 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -166,7 +166,7 @@ def is_pyop2_comm(comm): elif comm == MPI.COMM_NULL: if PYOP2_FINALIZED is False: # ~ import pytest; pytest.set_trace() - # ~ raise ValueError("COMM_NULL") + raise ValueError("COMM_NULL") ispyop2comm = True else: ispyop2comm = True @@ -247,7 +247,6 @@ def incref(comm): assert is_pyop2_comm(comm) refcount = comm.Get_attr(refcount_keyval) refcount[0] += 1 - debug(f'{comm.name} INCREF to {refcount[0]}') def decref(comm): @@ -259,10 +258,8 @@ def decref(comm): # ~ if not PYOP2_FINALIZED: refcount = comm.Get_attr(refcount_keyval) refcount[0] -= 1 - # ~ debug(f'{comm.name} DECREF to {refcount[0]}') if refcount[0] == 0 and not isinstance(comm, FriendlyCommNull): dupped_comms.remove(comm) - # ~ debug(f'Freeing {comm.name}') free_comm(comm) diff --git a/pyop2/parloop.py b/pyop2/parloop.py index c35f21ec3..6f4ad45e3 100644 --- a/pyop2/parloop.py +++ b/pyop2/parloop.py @@ -18,7 +18,6 @@ from pyop2.types import (Access, Global, AbstractDat, Dat, DatView, MixedDat, Mat, Set, MixedSet, ExtrudedSet, Subset, Map, ComposedMap, MixedMap) from pyop2.utils import cached_property -from pyop2.logger import debug class ParloopArg(abc.ABC): @@ -153,11 +152,9 @@ def __init__(self, global_knl, iterset, arguments): self.iterset = iterset self.comm = mpi.internal_comm(iterset.comm) self.arguments, self.reduced_globals = self.prepare_reduced_globals(arguments, global_knl) - debug(f"INIT {self.__class__} and assign {self.comm.name}") def __del__(self): if hasattr(self, "comm"): - debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") mpi.decref(self.comm) @property diff --git a/pyop2/sparsity.pyx b/pyop2/sparsity.pyx index 282ec042d..0f327e3db 100644 --- a/pyop2/sparsity.pyx +++ b/pyop2/sparsity.pyx @@ -124,7 +124,7 @@ def build_sparsity(sparsity): nest = sparsity.nested if mixed and sparsity.nested: raise ValueError("Can't build sparsity on mixed nest, build the sparsity on the blocks") - preallocator = PETSc.Mat().create(comm=sparsity.comm.ob_mpi) + preallocator = PETSc.Mat().create(comm=sparsity.comm) preallocator.setType(PETSc.Mat.Type.PREALLOCATOR) if mixed: # Sparsity is the dof sparsity. diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index 11580f3cd..7bd1195af 100644 --- a/pyop2/types/dat.py +++ b/pyop2/types/dat.py @@ -19,7 +19,6 @@ from pyop2.types.dataset import DataSet, GlobalDataSet, MixedDataSet from pyop2.types.data_carrier import DataCarrier, EmptyDataMixin, VecAccessMixin from pyop2.types.set import ExtrudedSet, GlobalSet, Set -from pyop2.logger import debug class AbstractDat(DataCarrier, EmptyDataMixin, abc.ABC): @@ -85,11 +84,9 @@ def __init__(self, dataset, data=None, dtype=None, name=None): self.comm = mpi.internal_comm(dataset.comm) self.halo_valid = True self._name = name or "dat_#x%x" % id(self) - debug(f"INIT {self.__class__} and assign {self.comm.name}") def __del__(self): if hasattr(self, "comm"): - debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") mpi.decref(self.comm) self._halo_frozen = False @@ -776,7 +773,6 @@ def what(x): raise ex.DataValueError('MixedDat with different dtypes is not supported') # TODO: Think about different communicators on dats (c.f. MixedSet) self.comm = mpi.internal_comm(self._dats[0].comm) - debug(f"INIT {self.__class__} and assign {self.comm.name}") @property def dat_version(self): diff --git a/pyop2/types/data_carrier.py b/pyop2/types/data_carrier.py index fcf5f95f1..73d3974c2 100644 --- a/pyop2/types/data_carrier.py +++ b/pyop2/types/data_carrier.py @@ -64,7 +64,6 @@ def __init__(self, data, dtype, shape): self._dtype = self._data.dtype @utils.cached_property - # ~ @property def _data(self): """Return the user-provided data buffer, or a zeroed buffer of the correct size if none was provided.""" diff --git a/pyop2/types/dataset.py b/pyop2/types/dataset.py index 8191db4e9..cbeb844fb 100644 --- a/pyop2/types/dataset.py +++ b/pyop2/types/dataset.py @@ -11,7 +11,6 @@ utils ) from pyop2.types.set import ExtrudedSet, GlobalSet, MixedSet, Set, Subset -from pyop2.logger import debug class DataSet(caching.ObjectCached): @@ -36,12 +35,10 @@ def __init__(self, iter_set, dim=1, name=None): self._cdim = np.prod(self._dim).item() self._name = name or "dset_#x%x" % id(self) self._initialized = True - debug(f"INIT {self.__class__} and assign {self.comm.name}") def __del__(self): - # ~ if hasattr(self, "comm"): + # Cannot use hasattr here if "comm" in self.__dict__: - # ~ debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") mpi.decref(self.comm) @classmethod @@ -217,7 +214,6 @@ def __init__(self, global_): self._globalset = GlobalSet(comm=self.comm) self._name = "gdset_#x%x" % id(self) self._initialized = True - debug(f"INIT {self.__class__} and assign {self.comm.name}") @classmethod def _cache_key(cls, *args): @@ -385,7 +381,6 @@ def __init__(self, arg, dims=None): comm = None self.comm = mpi.internal_comm(comm) self._initialized = True - debug(f"INIT {self.__class__} and assign {self.comm.name}") @classmethod def _process_args(cls, arg, dims=None): diff --git a/pyop2/types/glob.py b/pyop2/types/glob.py index 464eb2d6f..d40e2f37d 100644 --- a/pyop2/types/glob.py +++ b/pyop2/types/glob.py @@ -13,7 +13,6 @@ from pyop2.types.access import Access from pyop2.types.dataset import GlobalDataSet from pyop2.types.data_carrier import DataCarrier, EmptyDataMixin, VecAccessMixin -from pyop2.logger import debug class Global(DataCarrier, EmptyDataMixin, VecAccessMixin): @@ -40,7 +39,6 @@ class Global(DataCarrier, EmptyDataMixin, VecAccessMixin): @utils.validate_type(('name', str, ex.NameTypeError)) def __init__(self, dim, data=None, dtype=None, name=None, comm=None): - debug("calling Global.__init__") if isinstance(dim, Global): # If g is a Global, Global(g) performs a deep copy. This is for compatibility with Dat. self.__init__(dim._dim, None, dtype=dim.dtype, @@ -54,18 +52,15 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None): self._name = name or "global_#x%x" % id(self) self.comm = mpi.internal_comm(comm) # Object versioning setup - # ~ petsc_counter = (self.comm and self.dtype == PETSc.ScalarType) petsc_counter = (comm and self.dtype == PETSc.ScalarType) VecAccessMixin.__init__(self, petsc_counter=petsc_counter) try: name = self.comm.name except AttributeError: name = "None" - debug(f"INIT {self.__class__} and assign {name}") def __del__(self): - if hasattr(self, "comm") and self.comm is not None: - debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") + if hasattr(self, "comm"): mpi.decref(self.comm) @utils.cached_property diff --git a/pyop2/types/map.py b/pyop2/types/map.py index 516a9bd53..4b632d4c8 100644 --- a/pyop2/types/map.py +++ b/pyop2/types/map.py @@ -12,7 +12,6 @@ ) from pyop2 import mpi from pyop2.types.set import GlobalSet, MixedSet, Set -from pyop2.logger import debug class Map: @@ -53,11 +52,9 @@ def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, o self._offset_quotient = utils.verify_reshape(offset_quotient, dtypes.IntType, (arity, )) # A cache for objects built on top of this map self._cache = {} - debug(f"INIT {self.__class__} and assign {self.comm.name}") def __del__(self): if hasattr(self, "comm"): - debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") mpi.decref(self.comm) @utils.cached_property @@ -320,7 +317,6 @@ def __init__(self, maps): raise ex.MapTypeError("Don't know how to make communicator") self.comm = mpi.internal_comm(comms[0]) self._initialized = True - debug(f"INIT {self.__class__} and assign {self.comm.name}") @classmethod def _process_args(cls, *args, **kwargs): diff --git a/pyop2/types/mat.py b/pyop2/types/mat.py index b5a09e192..a3c65feef 100644 --- a/pyop2/types/mat.py +++ b/pyop2/types/mat.py @@ -20,7 +20,6 @@ from pyop2.types.dataset import DataSet, GlobalDataSet, MixedDataSet from pyop2.types.map import Map, ComposedMap from pyop2.types.set import MixedSet, Set, Subset -from pyop2.logger import debug class Sparsity(caching.ObjectCached): @@ -57,7 +56,6 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, if self._initialized: return - debug(f"INIT {self.__class__} BEGIN") self._block_sparse = block_sparse # Split into a list of row maps and a list of column maps maps, iteration_regions = zip(*maps) @@ -130,11 +128,9 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, self._o_nnz = onnz self._blocks = [[self]] self._initialized = True - debug(f"INIT {self.__class__} and assign {self.comm.name}") def __del__(self): if hasattr(self, "comm"): - debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") mpi.decref(self.comm) if hasattr(self, "lcomm"): mpi.decref(self.lcomm) @@ -377,7 +373,6 @@ def __init__(self, parent, i, j): if self._initialized: return - debug(f"INIT {self.__class__} BEGIN") self._dsets = (parent.dsets[0][i], parent.dsets[1][j]) self._rmaps = tuple(m.split[i] for m in parent.rmaps) self._cmaps = tuple(m.split[j] for m in parent.cmaps) @@ -393,7 +388,6 @@ def __init__(self, parent, i, j): # TODO: think about lcomm != rcomm self.comm = mpi.internal_comm(self.lcomm) self._initialized = True - debug(f"INIT {self.__class__} and assign {self.comm.name}") @classmethod def _process_args(cls, *args, **kwargs): @@ -458,11 +452,9 @@ def __init__(self, sparsity, dtype=None, name=None): self._datatype = np.dtype(dtype) self._name = name or "mat_#x%x" % id(self) self.assembly_state = Mat.ASSEMBLED - debug(f"INIT {self.__class__} and assign {self.comm.name}") def __del__(self): if hasattr(self, "comm"): - debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") mpi.decref(self.comm) if hasattr(self, "lcomm"): mpi.decref(self.lcomm) @@ -957,7 +949,6 @@ class MatBlock(AbstractMat): :arg j: The block column. """ def __init__(self, parent, i, j): - debug(f"INIT {self.__class__} BEGIN") self._parent = parent self._i = i self._j = j @@ -969,7 +960,6 @@ def __init__(self, parent, i, j): iscol=colis) self.comm = mpi.internal_comm(parent.comm) self.local_to_global_maps = self.handle.getLGMap() - debug(f"INIT {self.__class__} and assign {self.comm.name}") @property def dat_version(self): diff --git a/pyop2/types/set.py b/pyop2/types/set.py index 25d3b17e9..2615edd1a 100644 --- a/pyop2/types/set.py +++ b/pyop2/types/set.py @@ -11,7 +11,6 @@ mpi, utils ) -from pyop2.logger import debug class Set: @@ -78,12 +77,10 @@ def __init__(self, size, name=None, halo=None, comm=None): self._partition_size = 1024 # A cache of objects built on top of this set self._cache = {} - debug(f"INIT {self.__class__} and assign {self.comm.name}") def __del__(self): - # ~ if hasattr(self, "comm"): + # Cannot use hasattr here if "comm" in self.__dict__: - debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") mpi.decref(self.comm) @utils.cached_property @@ -227,11 +224,8 @@ class GlobalSet(Set): _argtypes_ = () def __init__(self, comm=None): - debug("calling GlobalSet.__init__") - # ~ import pdb; pdb.set_trace() self.comm = mpi.internal_comm(comm) self._cache = {} - debug(f"INIT {self.__class__} and assign {self.comm.name}") @utils.cached_property def core_size(self): @@ -337,7 +331,6 @@ def __init__(self, parent, layers, extruded_periodic=False): self._layers = layers self._extruded = True self._extruded_periodic = extruded_periodic - debug(f"INIT {self.__class__} and assign {self.comm.name}") @utils.cached_property def _kernel_args_(self): @@ -421,7 +414,6 @@ def __init__(self, superset, indices): len(self._indices)) self._extruded = superset._extruded self._extruded_periodic = superset._extruded_periodic - debug(f"INIT {self.__class__} and assign {self.comm.name}") @utils.cached_property def _kernel_args_(self): @@ -544,11 +536,9 @@ def __init__(self, sets): # TODO: do all sets need the same communicator? self.comm = mpi.internal_comm(functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.comm, sets))) self._initialized = True - debug(f"INIT {self.__class__} and assign {self.comm.name}") def __del__(self): if self._initialized and hasattr(self, "comm"): - debug(f"DELETE {self.__class__} and removing reference to {self.comm.name}") mpi.decref(self.comm) @utils.cached_property From b2520eeeea06220c0514218133176efa113590ce Mon Sep 17 00:00:00 2001 From: JDBetteridge Date: Thu, 20 Oct 2022 17:26:59 +0100 Subject: [PATCH 07/17] Fix up a few more MPI bits --- pyop2/mpi.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 7355c05cf..c4d1d764a 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -147,7 +147,14 @@ def __init__(self): self.name = 'PYOP2_FRIENDLY_COMM_NULL' def Get_attr(self, keyval): - return [1] + if keyval is refcount_keyval: + ret = [1] + elif keyval in (innercomm_keyval, outercomm_keyval, compilationcomm_keyval): + ret = None + return ret + + def Delete_attr(self, keyval): + pass def Free(self): pass @@ -255,11 +262,15 @@ def decref(comm): if comm == MPI.COMM_NULL: comm = FriendlyCommNull() assert is_pyop2_comm(comm) - # ~ if not PYOP2_FINALIZED: - refcount = comm.Get_attr(refcount_keyval) - refcount[0] -= 1 - if refcount[0] == 0 and not isinstance(comm, FriendlyCommNull): - dupped_comms.remove(comm) + if not PYOP2_FINALIZED: + refcount = comm.Get_attr(refcount_keyval) + refcount[0] -= 1 + if refcount[0] == 0 and not isinstance(comm, FriendlyCommNull): + dupped_comms.remove(comm) + free_comm(comm) + elif comm == MPI.COMM_NULL: + pass + else: free_comm(comm) @@ -279,7 +290,7 @@ def dup_comm(comm_in): comm_in.Set_attr(innercomm_keyval, comm_out) comm_out.Set_attr(outercomm_keyval, comm_in) # Name - # replace id() with .py2f() ??? + # TODO: replace id() with .py2f() ??? comm_out.Set_name(f"{comm_in.name or id(comm_in)}_DUP") # Refcount comm_out.Set_attr(refcount_keyval, [0]) @@ -398,9 +409,14 @@ def free_comm(comm): This only actually calls MPI_Comm_free once the refcount drops to zero. """ + # ~ if isinstance(comm, list): + # ~ import pytest; pytest.set_trace() if comm != MPI.COMM_NULL: assert is_pyop2_comm(comm) ocomm = comm.Get_attr(outercomm_keyval) + if isinstance(ocomm, list): + # No idea why this happens!? + ocomm = None if ocomm is not None: icomm = ocomm.Get_attr(innercomm_keyval) if icomm is None: From 07d1dc50ca9a37fcc0578484b7c38206d1e8c825 Mon Sep 17 00:00:00 2001 From: JDBetteridge Date: Mon, 24 Oct 2022 22:32:17 +0100 Subject: [PATCH 08/17] Fix deadlocks in Firedrake tests --- pyop2/mpi.py | 4 ++-- pyop2/parloop.py | 7 +++---- pyop2/types/dat.py | 6 +++--- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/pyop2/mpi.py b/pyop2/mpi.py index c4d1d764a..e84427859 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -261,8 +261,8 @@ def decref(comm): """ if comm == MPI.COMM_NULL: comm = FriendlyCommNull() - assert is_pyop2_comm(comm) if not PYOP2_FINALIZED: + assert is_pyop2_comm(comm) refcount = comm.Get_attr(refcount_keyval) refcount[0] -= 1 if refcount[0] == 0 and not isinstance(comm, FriendlyCommNull): @@ -416,7 +416,7 @@ def free_comm(comm): ocomm = comm.Get_attr(outercomm_keyval) if isinstance(ocomm, list): # No idea why this happens!? - ocomm = None + raise ValueError("Why have we got a list!?") if ocomm is not None: icomm = ocomm.Get_attr(innercomm_keyval) if icomm is None: diff --git a/pyop2/parloop.py b/pyop2/parloop.py index 6f4ad45e3..ac78e6bda 100644 --- a/pyop2/parloop.py +++ b/pyop2/parloop.py @@ -455,8 +455,7 @@ def _check_frozen_access_modes(cls, local_knl, arguments): "Dats with frozen halos must always be accessed with the same access mode" ) - @classmethod - def prepare_reduced_globals(cls, arguments, global_knl): + def prepare_reduced_globals(self, arguments, global_knl): """Swap any :class:`GlobalParloopArg` instances that are INC'd into with zeroed replacements. @@ -466,9 +465,9 @@ def prepare_reduced_globals(cls, arguments, global_knl): """ arguments = list(arguments) reduced_globals = {} - for i, (lk_arg, gk_arg, pl_arg) in enumerate(cls.zip_arguments(global_knl, arguments)): + for i, (lk_arg, gk_arg, pl_arg) in enumerate(self.zip_arguments(global_knl, arguments)): if isinstance(gk_arg, GlobalKernelArg) and lk_arg.access == Access.INC: - tmp = Global(gk_arg.dim, data=np.zeros_like(pl_arg.data.data_ro), dtype=lk_arg.dtype) + tmp = Global(gk_arg.dim, data=np.zeros_like(pl_arg.data.data_ro), dtype=lk_arg.dtype, comm=self.comm) reduced_globals[tmp] = pl_arg arguments[i] = GlobalParloopArg(tmp) diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index 7bd1195af..615a2f82c 100644 --- a/pyop2/types/dat.py +++ b/pyop2/types/dat.py @@ -85,13 +85,13 @@ def __init__(self, dataset, data=None, dtype=None, name=None): self.halo_valid = True self._name = name or "dat_#x%x" % id(self) + self._halo_frozen = False + self._frozen_access_mode = None + def __del__(self): if hasattr(self, "comm"): mpi.decref(self.comm) - self._halo_frozen = False - self._frozen_access_mode = None - @utils.cached_property def _kernel_args_(self): return (self._data.ctypes.data, ) From 5c4b242641d8b0e45a20331cd47a184c06f7c02e Mon Sep 17 00:00:00 2001 From: JDBetteridge Date: Wed, 26 Oct 2022 13:55:11 +0100 Subject: [PATCH 09/17] Comm in composed map was not internal --- pyop2/types/map.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyop2/types/map.py b/pyop2/types/map.py index 4b632d4c8..91224d52a 100644 --- a/pyop2/types/map.py +++ b/pyop2/types/map.py @@ -251,7 +251,7 @@ def __init__(self, *maps_, name=None): raise ex.MapTypeError("frommap.arity must be 1") self._iterset = maps_[-1].iterset self._toset = maps_[0].toset - self.comm = self._toset.comm + self.comm = mpi.internal_comm(self._toset.comm) self._arity = maps_[0].arity # Don't call super().__init__() to avoid calling verify_reshape() self._values = None From 390289f90f4af45672d70a24f9d7048306e18f34 Mon Sep 17 00:00:00 2001 From: JDBetteridge Date: Wed, 26 Oct 2022 14:34:22 +0100 Subject: [PATCH 10/17] Remove pyop2.mpi.FriendlyCommNull --- pyop2/mpi.py | 25 +++---------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/pyop2/mpi.py b/pyop2/mpi.py index e84427859..d8c59e350 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -142,24 +142,6 @@ def delcomm_outer(comm, keyval, icomm): dupped_comms = [] -class FriendlyCommNull: - def __init__(self): - self.name = 'PYOP2_FRIENDLY_COMM_NULL' - - def Get_attr(self, keyval): - if keyval is refcount_keyval: - ret = [1] - elif keyval in (innercomm_keyval, outercomm_keyval, compilationcomm_keyval): - ret = None - return ret - - def Delete_attr(self, keyval): - pass - - def Free(self): - pass - - def is_pyop2_comm(comm): """Returns `True` if `comm` is a PyOP2 communicator, False if `comm` another communicator. @@ -177,7 +159,7 @@ def is_pyop2_comm(comm): ispyop2comm = True else: ispyop2comm = True - elif isinstance(comm, (MPI.Comm, FriendlyCommNull)): + elif isinstance(comm, MPI.Comm): ispyop2comm = bool(comm.Get_attr(refcount_keyval)) else: raise ValueError(f"Argument passed to is_pyop2_comm() is a {type(comm)}, which is not a recognised comm type") @@ -259,13 +241,11 @@ def incref(comm): def decref(comm): """ Decrement communicator reference count """ - if comm == MPI.COMM_NULL: - comm = FriendlyCommNull() if not PYOP2_FINALIZED: assert is_pyop2_comm(comm) refcount = comm.Get_attr(refcount_keyval) refcount[0] -= 1 - if refcount[0] == 0 and not isinstance(comm, FriendlyCommNull): + if refcount[0] == 0: dupped_comms.remove(comm) free_comm(comm) elif comm == MPI.COMM_NULL: @@ -388,6 +368,7 @@ def compilation_comm(comm): retcomm = get_compilation_comm(comm) if retcomm is not None: debug("Found existing compilation communicator") + debug(f"{retcomm.name}") else: retcomm = create_split_comm(comm) set_compilation_comm(comm, retcomm) From ff6740739527bdbbc2ced15edb369ccc440b69a5 Mon Sep 17 00:00:00 2001 From: JDBetteridge Date: Thu, 27 Oct 2022 12:21:56 +0100 Subject: [PATCH 11/17] Address reviewer comments --- pyop2/compilation.py | 8 ++-- pyop2/mpi.py | 98 +++++++++++++++++++++++------------------- pyop2/types/dataset.py | 3 +- pyop2/types/glob.py | 4 -- pyop2/types/set.py | 3 +- 5 files changed, 61 insertions(+), 55 deletions(-) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index 2dad49e51..32f743c2f 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -178,7 +178,9 @@ class Compiler(ABC): _debugflags = () def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, comm=None): + # Get compiler version ASAP since it is used in __repr__ self.sniff_compiler_version() + self._extra_compiler_flags = tuple(extra_compiler_flags) self._extra_linker_flags = tuple(extra_linker_flags) @@ -188,7 +190,6 @@ def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, co # Compilation communicators are reference counted on the PyOP2 comm self.pcomm = mpi.internal_comm(comm) self.comm = mpi.compilation_comm(self.pcomm) - self.sniff_compiler_version() def __del__(self): if hasattr(self, "comm"): @@ -597,9 +598,8 @@ def __init__(self, code, argtypes): else: exe = configuration["cc"] or "mpicc" compiler = sniff_compiler(exe) - x = compiler(cppargs, ldargs, cpp=cpp, comm=comm) - dll = x.get_so(code, extension) - del x + dll = compiler(cppargs, ldargs, cpp=cpp, comm=comm).get_so(code, extension) + if isinstance(jitmodule, GlobalKernel): _add_profiling_events(dll, code.local_kernel.events) diff --git a/pyop2/mpi.py b/pyop2/mpi.py index d8c59e350..4e17d3768 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -154,7 +154,6 @@ def is_pyop2_comm(comm): ispyop2comm = False elif comm == MPI.COMM_NULL: if PYOP2_FINALIZED is False: - # ~ import pytest; pytest.set_trace() raise ValueError("COMM_NULL") ispyop2comm = True else: @@ -184,26 +183,30 @@ def pyop2_comm_status(): class PyOP2Comm: - """ Suitable for using a PyOP2 internal communicator suitably - incrementing and decrementing the comm. + """ Use a PyOP2 internal communicator and + increment and decrement the internal comm. + :arg comm: Any communicator """ def __init__(self, comm): - self.comm = comm - self._comm = None + self.user_comm = comm + self.internal_comm = None def __enter__(self): - self._comm = internal_comm(self.comm) - return self._comm + """ Returns an internal comm tat will be safely decref'd + when leaving the context manager + + :returns pyop2_comm: A PyOP2 internal communicator + """ + self.internal_comm = internal_comm(self.user_comm) + return self.internal_comm def __exit__(self, exc_type, exc_value, traceback): - decref(self._comm) - self._comm = None + decref(self.internal_comm) + self.internal_comm = None def internal_comm(comm): - """ Creates an internal comm from the comm passed in - This happens on nearly every PyOP2 object so this avoids unnecessary - repetition. + """ Creates an internal comm from the user comm :arg comm: A communicator or None :returns pyop2_comm: A PyOP2 internal communicator @@ -223,7 +226,6 @@ def internal_comm(comm): # Ensure comm is not the NULL communicator raise ValueError("MPI_COMM_NULL passed to internal_comm()") elif not isinstance(comm, MPI.Comm): - # If it is not an MPI.Comm raise error raise ValueError("Don't know how to dup a %r" % type(comm)) else: pyop2_comm = dup_comm(comm) @@ -241,6 +243,7 @@ def incref(comm): def decref(comm): """ Decrement communicator reference count """ + global PYOP2_FINALIZED if not PYOP2_FINALIZED: assert is_pyop2_comm(comm) refcount = comm.Get_attr(refcount_keyval) @@ -259,34 +262,41 @@ def dup_comm(comm_in): :arg comm_in: Communicator to duplicate - :returns: An mpi4py communicator.""" + :returns internal_comm: An internal (PyOP2) communicator.""" assert not is_pyop2_comm(comm_in) # Check if communicator has an embedded PyOP2 comm. - comm_out = comm_in.Get_attr(innercomm_keyval) - if comm_out is None: + internal_comm = comm_in.Get_attr(innercomm_keyval) + if internal_comm is None: # Haven't seen this comm before, duplicate it. - comm_out = comm_in.Dup() - comm_in.Set_attr(innercomm_keyval, comm_out) - comm_out.Set_attr(outercomm_keyval, comm_in) + internal_comm = comm_in.Dup() + comm_in.Set_attr(innercomm_keyval, internal_comm) + internal_comm.Set_attr(outercomm_keyval, comm_in) # Name - # TODO: replace id() with .py2f() ??? - comm_out.Set_name(f"{comm_in.name or id(comm_in)}_DUP") + internal_comm.Set_name(f"{comm_in.name or comm_in.py2f()}_DUP") # Refcount - comm_out.Set_attr(refcount_keyval, [0]) - incref(comm_out) + internal_comm.Set_attr(refcount_keyval, [0]) + incref(internal_comm) # Remember we need to destroy it. - dupped_comms.append(comm_out) - elif is_pyop2_comm(comm_out): + dupped_comms.append(internal_comm) + elif is_pyop2_comm(internal_comm): # Inner comm is a PyOP2 comm, return it - incref(comm_out) + incref(internal_comm) else: raise ValueError("Inner comm is not a PyOP2 comm") - return comm_out + return internal_comm @collective def create_split_comm(comm): + """ Create a split communicator based on either shared memory access + if using MPI >= 3, or shared local disk access if using MPI >= 3. + Used internally for creating compilation communicators + + :arg comm: A communicator to split + + :return split_comm: A split communicator + """ if MPI.VERSION >= 3: debug("Creating compilation communicator using MPI_Split_type") split_comm = comm.Split_type(MPI.COMM_TYPE_SHARED) @@ -316,7 +326,7 @@ def create_split_comm(comm): split_comm = comm.Split(color=min(ranks), key=comm.rank) debug("Finished creating compilation communicator using filesystem colors") # Name - split_comm.Set_name(f"{comm.name or id(comm)}_COMPILATION") + split_comm.Set_name(f"{comm.name or comm.py2f()}_COMPILATION") # Refcount split_comm.Set_attr(refcount_keyval, [0]) incref(split_comm) @@ -327,31 +337,31 @@ def get_compilation_comm(comm): return comm.Get_attr(compilationcomm_keyval) -def set_compilation_comm(comm, inner): - """Set the compilation communicator. +def set_compilation_comm(comm, comp_comm): + """Stash the compilation communicator (`comp_comm`) on the + PyOP2 communicator `comm` :arg comm: A PyOP2 Communicator - :arg inner: The compilation communicator + :arg comp_comm: The compilation communicator """ - # Ensure `comm` is a PyOP2 comm if not is_pyop2_comm(comm): raise ValueError("Compilation communicator must be stashed on a PyOP2 comm") # Check if the compilation communicator is already set - old_inner = comm.Get_attr(compilationcomm_keyval) - if old_inner is not None: - if is_pyop2_comm(old_inner): + old_comp_comm = comm.Get_attr(compilationcomm_keyval) + if old_comp_comm is not None: + if is_pyop2_comm(old_comp_comm): raise ValueError("Compilation communicator is not a PyOP2 comm, something is very broken!") else: - decref(old_inner) + decref(old_comp_comm) - if not is_pyop2_comm(inner): + if not is_pyop2_comm(comp_comm): raise ValueError( "Communicator used for compilation communicator must be a PyOP2 communicator.\n" "Use pyop2.mpi.dup_comm() to create a PyOP2 comm from an existing comm.") else: - # Stash `inner` as an attribute on `comm` - comm.Set_attr(compilationcomm_keyval, inner) + # Stash `comp_comm` as an attribute on `comm` + comm.Set_attr(compilationcomm_keyval, comp_comm) @collective @@ -390,14 +400,9 @@ def free_comm(comm): This only actually calls MPI_Comm_free once the refcount drops to zero. """ - # ~ if isinstance(comm, list): - # ~ import pytest; pytest.set_trace() if comm != MPI.COMM_NULL: assert is_pyop2_comm(comm) ocomm = comm.Get_attr(outercomm_keyval) - if isinstance(ocomm, list): - # No idea why this happens!? - raise ValueError("Why have we got a list!?") if ocomm is not None: icomm = ocomm.Get_attr(innercomm_keyval) if icomm is None: @@ -451,7 +456,10 @@ def free_comms(): def hash_comm(comm): """Return a hashable identifier for a communicator.""" - assert is_pyop2_comm(comm) + if not is_pyop2_comm(comm): + ValueError("`comm` passed to `hash_comm()` must be a PyOP2 communicator") + # `comm` must be a PyOP2 communicator so we can use its id() + # as the hash and this is stable between invocations. return id(comm) diff --git a/pyop2/types/dataset.py b/pyop2/types/dataset.py index cbeb844fb..14b9b6400 100644 --- a/pyop2/types/dataset.py +++ b/pyop2/types/dataset.py @@ -37,7 +37,8 @@ def __init__(self, iter_set, dim=1, name=None): self._initialized = True def __del__(self): - # Cannot use hasattr here + # Cannot use hasattr here, since we define `__getattr__` + # This causes infinite recursion when looked up! if "comm" in self.__dict__: mpi.decref(self.comm) diff --git a/pyop2/types/glob.py b/pyop2/types/glob.py index d40e2f37d..751a33792 100644 --- a/pyop2/types/glob.py +++ b/pyop2/types/glob.py @@ -54,10 +54,6 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None): # Object versioning setup petsc_counter = (comm and self.dtype == PETSc.ScalarType) VecAccessMixin.__init__(self, petsc_counter=petsc_counter) - try: - name = self.comm.name - except AttributeError: - name = "None" def __del__(self): if hasattr(self, "comm"): diff --git a/pyop2/types/set.py b/pyop2/types/set.py index 2615edd1a..1f6ea30c8 100644 --- a/pyop2/types/set.py +++ b/pyop2/types/set.py @@ -79,7 +79,8 @@ def __init__(self, size, name=None, halo=None, comm=None): self._cache = {} def __del__(self): - # Cannot use hasattr here + # Cannot use hasattr here, since child classes define `__getattr__` + # This causes infinite recursion when looked up! if "comm" in self.__dict__: mpi.decref(self.comm) From 5563c8500e935b4b4fab010e4c62ea58a4d51c7c Mon Sep 17 00:00:00 2001 From: JDBetteridge Date: Thu, 27 Oct 2022 13:03:26 +0100 Subject: [PATCH 12/17] pyop2_comm_status() now returns a string --- pyop2/logger.py | 3 --- pyop2/mpi.py | 16 ++++++++-------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/pyop2/logger.py b/pyop2/logger.py index 833eeb8c2..2e58e3446 100644 --- a/pyop2/logger.py +++ b/pyop2/logger.py @@ -40,9 +40,6 @@ handler = logging.StreamHandler() logger.addHandler(handler) -fhandler = logging.FileHandler('pyop2.log') -logger.addHandler(fhandler) - debug = logger.debug info = logger.info diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 4e17d3768..ccfc993ce 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -168,18 +168,19 @@ def is_pyop2_comm(comm): def pyop2_comm_status(): """ Prints the reference counts for all comms PyOP2 has duplicated """ - print('PYOP2 Communicator reference counts:') - print('| Communicator name | Count |') - print('==================================================') + status_string = 'PYOP2 Communicator reference counts:\n' + status_string += '| Communicator name | Count |\n' + status_string += '==================================================\n' for comm in dupped_comms: if comm == MPI.COMM_NULL: null = 'COMM_NULL' - print(f'| {null:39}| {0:5d} |') + status_string += f'| {null:39}| {0:5d} |\n' else: refcount = comm.Get_attr(refcount_keyval)[0] if refcount is None: refcount = -999 - print(f'| {comm.name:39}| {refcount:5d} |') + status_string += f'| {comm.name:39}| {refcount:5d} |\n' + return status_string class PyOP2Comm: @@ -429,15 +430,14 @@ def free_comm(comm): @atexit.register def free_comms(): """Free all outstanding communicators.""" - # Collect garbage as it may hold on to communicator references global PYOP2_FINALIZED PYOP2_FINALIZED = True debug("PyOP2 Finalizing") + # Collect garbage as it may hold on to communicator references debug("Calling gc.collect()") import gc gc.collect() - pyop2_comm_status() - print(dupped_comms) + debug(pyop2_comm_status()) debug(f"Freeing comms in list (length {len(dupped_comms)})") while dupped_comms: c = dupped_comms[-1] From f01762daff02b9875fece3865dbe2d47e571c356 Mon Sep 17 00:00:00 2001 From: JDBetteridge Date: Thu, 27 Oct 2022 13:15:28 +0100 Subject: [PATCH 13/17] Duplicate COMM_WORLD and COMM_SELF for PyOP2 use (and avoid renaming MPI_COMM_WORLD) --- pyop2/mpi.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyop2/mpi.py b/pyop2/mpi.py index ccfc993ce..6f69718db 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -48,10 +48,10 @@ # These are user-level communicators, we never send any messages on # them inside PyOP2. -COMM_WORLD = PETSc.COMM_WORLD.tompi4py() +COMM_WORLD = PETSc.COMM_WORLD.tompi4py().Dup() COMM_WORLD.Set_name("PYOP2_COMM_WORLD") -COMM_SELF = PETSc.COMM_SELF.tompi4py() +COMM_SELF = PETSc.COMM_SELF.tompi4py().Dup() COMM_SELF.Set_name("PYOP2_COMM_SELF") PYOP2_FINALIZED = False @@ -452,6 +452,8 @@ def free_comms(): outercomm_keyval, compilationcomm_keyval]: MPI.Comm.Free_keyval(kv) + COMM_WORLD.Free() + COMM_SELF.Free() def hash_comm(comm): From 235d45befab56be2d52bb8bd42f42f449d2ecb19 Mon Sep 17 00:00:00 2001 From: JDBetteridge Date: Fri, 4 Nov 2022 16:15:31 +0000 Subject: [PATCH 14/17] Fixed some unreachable lines and redundant logic --- pyop2/mpi.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 6f69718db..c0ffa5e1e 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -154,8 +154,7 @@ def is_pyop2_comm(comm): ispyop2comm = False elif comm == MPI.COMM_NULL: if PYOP2_FINALIZED is False: - raise ValueError("COMM_NULL") - ispyop2comm = True + raise ValueError("Communicator passed to is_pyop2_comm() is COMM_NULL") else: ispyop2comm = True elif isinstance(comm, MPI.Comm): @@ -252,9 +251,7 @@ def decref(comm): if refcount[0] == 0: dupped_comms.remove(comm) free_comm(comm) - elif comm == MPI.COMM_NULL: - pass - else: + elif comm != MPI.COMM_NULL: free_comm(comm) From 2c6056b6c2332e1667e1a7e46159bea7cc090267 Mon Sep 17 00:00:00 2001 From: JDBetteridge Date: Tue, 15 Nov 2022 16:47:38 +0000 Subject: [PATCH 15/17] Change debug to print in _free_comms as stream already closed --- pyop2/mpi.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pyop2/mpi.py b/pyop2/mpi.py index c0ffa5e1e..22e040764 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -40,7 +40,7 @@ import os from pyop2.configuration import configuration from pyop2.exceptions import CompilationError -from pyop2.logger import warning, debug +from pyop2.logger import warning, debug, logger, DEBUG from pyop2.utils import trim @@ -425,10 +425,14 @@ def free_comm(comm): @atexit.register -def free_comms(): +def _free_comms(): """Free all outstanding communicators.""" global PYOP2_FINALIZED PYOP2_FINALIZED = True + if logger.level > DEBUG: + debug = lambda string: None + else: + debug = lambda string: print(string) debug("PyOP2 Finalizing") # Collect garbage as it may hold on to communicator references debug("Calling gc.collect()") @@ -442,7 +446,7 @@ def free_comms(): refcount = c.Get_attr(refcount_keyval) debug(f"Freeing {c.name}, which has refcount {refcount[0]}") else: - debug("Freeing non PyOP2 comm in `free_comms()`") + debug("Freeing non PyOP2 comm in `_free_comms()`") free_comm(c) for kv in [refcount_keyval, innercomm_keyval, From ba86a166ad04b9c3508d295b67d34ab7204f7d04 Mon Sep 17 00:00:00 2001 From: JDBetteridge Date: Fri, 18 Nov 2022 19:06:35 +0000 Subject: [PATCH 16/17] Fixed removing comm from list twice on free --- pyop2/mpi.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 22e040764..83211af14 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -249,7 +249,6 @@ def decref(comm): refcount = comm.Get_attr(refcount_keyval) refcount[0] -= 1 if refcount[0] == 0: - dupped_comms.remove(comm) free_comm(comm) elif comm != MPI.COMM_NULL: free_comm(comm) From 4c439ea032c7cb38ad8d9d173de59eda1158fbc9 Mon Sep 17 00:00:00 2001 From: JDBetteridge Date: Wed, 23 Nov 2022 14:46:36 +0000 Subject: [PATCH 17/17] Tidy code, address review comments --- pyop2/mpi.py | 72 +++++++++++++++++++++--------------------- pyop2/types/dataset.py | 3 +- pyop2/types/mat.py | 6 ++-- 3 files changed, 41 insertions(+), 40 deletions(-) diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 83211af14..66fa10f88 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -37,14 +37,18 @@ from petsc4py import PETSc from mpi4py import MPI # noqa import atexit +import gc +import glob import os +import tempfile + from pyop2.configuration import configuration from pyop2.exceptions import CompilationError from pyop2.logger import warning, debug, logger, DEBUG from pyop2.utils import trim -__all__ = ("COMM_WORLD", "COMM_SELF", "MPI", "internal_comm", "is_pyop2_comm", "incref", "decref", "PyOP2Comm") +__all__ = ("COMM_WORLD", "COMM_SELF", "MPI", "internal_comm", "is_pyop2_comm", "incref", "decref", "temp_internal_comm") # These are user-level communicators, we never send any messages on # them inside PyOP2. @@ -113,7 +117,7 @@ def delcomm_outer(comm, keyval, icomm): :arg comm: Outer communicator. :arg keyval: The MPI keyval, should be ``innercomm_keyval``. :arg icomm: The inner communicator, should have a reference to - ``comm`. + ``comm``. """ if keyval != innercomm_keyval: raise ValueError("Unexpected keyval") @@ -143,9 +147,9 @@ def delcomm_outer(comm, keyval, icomm): def is_pyop2_comm(comm): - """Returns `True` if `comm` is a PyOP2 communicator, + """Returns ``True`` if ``comm`` is a PyOP2 communicator, False if `comm` another communicator. - Raises exception if `comm` is not a communicator. + Raises exception if ``comm`` is not a communicator. :arg comm: Communicator to query """ @@ -153,7 +157,7 @@ def is_pyop2_comm(comm): if isinstance(comm, PETSc.Comm): ispyop2comm = False elif comm == MPI.COMM_NULL: - if PYOP2_FINALIZED is False: + if not PYOP2_FINALIZED: raise ValueError("Communicator passed to is_pyop2_comm() is COMM_NULL") else: ispyop2comm = True @@ -182,51 +186,54 @@ def pyop2_comm_status(): return status_string -class PyOP2Comm: +class temp_internal_comm: """ Use a PyOP2 internal communicator and increment and decrement the internal comm. :arg comm: Any communicator """ def __init__(self, comm): self.user_comm = comm - self.internal_comm = None + self.internal_comm = internal_comm(self.user_comm) + + def __del__(self): + decref(self.internal_comm) def __enter__(self): - """ Returns an internal comm tat will be safely decref'd - when leaving the context manager + """ Returns an internal comm that will be safely decref'd + when the context manager is destroyed :returns pyop2_comm: A PyOP2 internal communicator """ - self.internal_comm = internal_comm(self.user_comm) return self.internal_comm def __exit__(self, exc_type, exc_value, traceback): - decref(self.internal_comm) - self.internal_comm = None + pass def internal_comm(comm): - """ Creates an internal comm from the user comm + """ Creates an internal comm from the user comm. + If comm is None, create an internal communicator from COMM_WORLD :arg comm: A communicator or None :returns pyop2_comm: A PyOP2 internal communicator """ + # Parse inputs if comm is None: # None will be the default when creating most objects - pyop2_comm = dup_comm(COMM_WORLD) - elif is_pyop2_comm(comm): - # Increase the reference count and return same comm if - # already an internal communicator - incref(comm) - pyop2_comm = comm + comm = COMM_WORLD elif isinstance(comm, PETSc.Comm): - # Convert PETSc.Comm to mpi4py.MPI.Comm - pyop2_comm = dup_comm(comm.tompi4py()) - elif comm == MPI.COMM_NULL: - # Ensure comm is not the NULL communicator + comm = comm.tompi4py() + + # Check for invalid inputs + if comm == MPI.COMM_NULL: raise ValueError("MPI_COMM_NULL passed to internal_comm()") elif not isinstance(comm, MPI.Comm): raise ValueError("Don't know how to dup a %r" % type(comm)) + + # Handle a valid input + if is_pyop2_comm(comm): + incref(comm) + pyop2_comm = comm else: pyop2_comm = dup_comm(comm) return pyop2_comm @@ -243,7 +250,6 @@ def incref(comm): def decref(comm): """ Decrement communicator reference count """ - global PYOP2_FINALIZED if not PYOP2_FINALIZED: assert is_pyop2_comm(comm) refcount = comm.Get_attr(refcount_keyval) @@ -287,7 +293,7 @@ def dup_comm(comm_in): @collective def create_split_comm(comm): """ Create a split communicator based on either shared memory access - if using MPI >= 3, or shared local disk access if using MPI >= 3. + if using MPI >= 3, or shared local disk access if using MPI <= 3. Used internally for creating compilation communicators :arg comm: A communicator to split @@ -300,7 +306,6 @@ def create_split_comm(comm): debug("Finished creating compilation communicator using MPI_Split_type") else: debug("Creating compilation communicator using MPI_Split + filesystem") - import tempfile if comm.rank == 0: if not os.path.exists(configuration["cache_dir"]): os.makedirs(configuration["cache_dir"], exist_ok=True) @@ -316,7 +321,6 @@ def create_split_comm(comm): with open(os.path.join(tmpname, str(comm.rank)), "wb"): pass comm.barrier() - import glob ranks = sorted(int(os.path.basename(name)) for name in glob.glob("%s/[0-9]*" % tmpname)) debug("Creating compilation communicator using filesystem colors") @@ -335,8 +339,8 @@ def get_compilation_comm(comm): def set_compilation_comm(comm, comp_comm): - """Stash the compilation communicator (`comp_comm`) on the - PyOP2 communicator `comm` + """Stash the compilation communicator (``comp_comm``) on the + PyOP2 communicator ``comm`` :arg comm: A PyOP2 Communicator :arg comp_comm: The compilation communicator @@ -435,17 +439,13 @@ def _free_comms(): debug("PyOP2 Finalizing") # Collect garbage as it may hold on to communicator references debug("Calling gc.collect()") - import gc gc.collect() debug(pyop2_comm_status()) debug(f"Freeing comms in list (length {len(dupped_comms)})") while dupped_comms: c = dupped_comms[-1] - if is_pyop2_comm(c): - refcount = c.Get_attr(refcount_keyval) - debug(f"Freeing {c.name}, which has refcount {refcount[0]}") - else: - debug("Freeing non PyOP2 comm in `_free_comms()`") + refcount = c.Get_attr(refcount_keyval) + debug(f"Freeing {c.name}, which has refcount {refcount[0]}") free_comm(c) for kv in [refcount_keyval, innercomm_keyval, @@ -459,7 +459,7 @@ def _free_comms(): def hash_comm(comm): """Return a hashable identifier for a communicator.""" if not is_pyop2_comm(comm): - ValueError("`comm` passed to `hash_comm()` must be a PyOP2 communicator") + raise ValueError("`comm` passed to `hash_comm()` must be a PyOP2 communicator") # `comm` must be a PyOP2 communicator so we can use its id() # as the hash and this is stable between invocations. return id(comm) diff --git a/pyop2/types/dataset.py b/pyop2/types/dataset.py index 14b9b6400..4e114032a 100644 --- a/pyop2/types/dataset.py +++ b/pyop2/types/dataset.py @@ -376,7 +376,8 @@ def __init__(self, arg, dims=None): return self._dsets = arg try: - # Try/except may not be necessary, someone needs to think about this... + # Try to choose the comm to be the same as the first set + # of the MixedDataSet comm = self._process_args(arg, dims)[0][0].comm except AttributeError: comm = None diff --git a/pyop2/types/mat.py b/pyop2/types/mat.py index a3c65feef..aefd77de1 100644 --- a/pyop2/types/mat.py +++ b/pyop2/types/mat.py @@ -1113,7 +1113,7 @@ def mult(self, mat, x, y): a[0] = x.array_r else: x.array_r - with mpi.PyOP2Comm(x.comm) as comm: + with mpi.temp_internal_comm(x.comm) as comm: comm.bcast(a) return y.scale(a) else: @@ -1130,7 +1130,7 @@ def multTranspose(self, mat, x, y): a[0] = x.array_r else: x.array_r - with mpi.PyOP2Comm(x.comm) as comm: + with mpi.temp_internal_comm(x.comm) as comm: comm.bcast(a) y.scale(a) else: @@ -1155,7 +1155,7 @@ def multTransposeAdd(self, mat, x, y, z): a[0] = x.array_r else: x.array_r - with mpi.PyOP2Comm(x.comm) as comm: + with mpi.temp_internal_comm(x.comm) as comm: comm.bcast(a) if y == z: # Last two arguments are aliased.