Skip to content

Commit

Permalink
Replace __del__ with weakref.finalizer
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Nov 16, 2023
1 parent ca73fb6 commit fee60d2
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 60 deletions.
9 changes: 3 additions & 6 deletions pyop2/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import shlex
from hashlib import md5
from packaging.version import Version, InvalidVersion
import weakref


from pyop2 import mpi
Expand Down Expand Up @@ -189,13 +190,9 @@ def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, co

# Compilation communicators are reference counted on the PyOP2 comm
self.pcomm = mpi.internal_comm(comm)
weakref.finalize(self, mpi.decref, self.pcomm)
self.comm = mpi.compilation_comm(self.pcomm)

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)
if hasattr(self, "pcomm"):
mpi.decref(self.pcomm)
weakref.finalize(self, mpi.decref, self.comm)

def __repr__(self):
return f"<{self._name} compiler, version {self.version or 'unknown'}>"
Expand Down
5 changes: 2 additions & 3 deletions pyop2/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import glob
import os
import tempfile
import weakref

from pyop2.configuration import configuration
from pyop2.exceptions import CompilationError
Expand Down Expand Up @@ -267,9 +268,7 @@ class temp_internal_comm:
def __init__(self, comm):
self.user_comm = comm
self.internal_comm = internal_comm(self.user_comm)

def __del__(self):
decref(self.internal_comm)
weakref.finalize(self, decref, self.internal_comm)

def __enter__(self):
""" Returns an internal comm that will be safely decref'd
Expand Down
6 changes: 2 additions & 4 deletions pyop2/parloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import operator
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import weakref

import loopy as lp
import numpy as np
Expand Down Expand Up @@ -152,12 +153,9 @@ def __init__(self, global_knl, iterset, arguments):
self.global_kernel = global_knl
self.iterset = iterset
self.comm = mpi.internal_comm(iterset.comm)
weakref.finalize(self, mpi.decref, self.comm)
self.arguments, self.reduced_globals = self.prepare_reduced_globals(arguments, global_knl)

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)

@property
def local_kernel(self):
return self.global_kernel.local_kernel
Expand Down
7 changes: 3 additions & 4 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ctypes
import itertools
import operator
import weakref

import loopy as lp
import numpy as np
Expand Down Expand Up @@ -83,16 +84,13 @@ def __init__(self, dataset, data=None, dtype=None, name=None):

self._dataset = dataset
self.comm = mpi.internal_comm(dataset.comm)
weakref.finalize(self, mpi.decref, self.comm)
self.halo_valid = True
self._name = name or "dat_#x%x" % id(self)

self._halo_frozen = False
self._frozen_access_mode = None

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)

@utils.cached_property
def _kernel_args_(self):
return (self._data.ctypes.data, )
Expand Down Expand Up @@ -824,6 +822,7 @@ def what(x):
raise ex.DataValueError('MixedDat with different dtypes is not supported')
# TODO: Think about different communicators on dats (c.f. MixedSet)
self.comm = mpi.internal_comm(self._dats[0].comm)
weakref.finalize(self, mpi.decref, self.comm)

@property
def dat_version(self):
Expand Down
10 changes: 4 additions & 6 deletions pyop2/types/dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numbers
import weakref

import numpy as np
from petsc4py import PETSc
Expand Down Expand Up @@ -30,18 +31,13 @@ def __init__(self, iter_set, dim=1, name=None):
if isinstance(iter_set, Subset):
raise NotImplementedError("Deriving a DataSet from a Subset is unsupported")
self.comm = mpi.internal_comm(iter_set.comm)
weakref.finalize(self, mpi.decref, self.comm)
self._set = iter_set
self._dim = utils.as_tuple(dim, numbers.Integral)
self._cdim = np.prod(self._dim).item()
self._name = name or "dset_#x%x" % id(self)
self._initialized = True

def __del__(self):
# Cannot use hasattr here, since we define `__getattr__`
# This causes infinite recursion when looked up!
if "comm" in self.__dict__:
mpi.decref(self.comm)

@classmethod
def _process_args(cls, *args, **kwargs):
return (args[0], ) + args, kwargs
Expand Down Expand Up @@ -212,6 +208,7 @@ def __init__(self, global_):
return
self._global = global_
self.comm = mpi.internal_comm(global_.comm)
weakref.finalize(self, mpi.decref, self.comm)
self._globalset = GlobalSet(comm=self.comm)
self._name = "gdset_#x%x" % id(self)
self._initialized = True
Expand Down Expand Up @@ -382,6 +379,7 @@ def __init__(self, arg, dims=None):
except AttributeError:
comm = None
self.comm = mpi.internal_comm(comm)
weakref.finalize(self, mpi.decref, self.comm)
self._initialized = True

@classmethod
Expand Down
12 changes: 5 additions & 7 deletions pyop2/types/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import ctypes
import operator
import warnings
import weakref

import numpy as np
from petsc4py import PETSc
Expand All @@ -26,9 +27,9 @@ def __init__(self, dim, data=None, dtype=None, name=None):
self._buf = np.empty(self.shape, dtype=self.dtype)
self._name = name or "%s_#x%x" % (self.__class__.__name__.lower(), id(self))

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)
# ~ def __del__(self): # TODO !?
# ~ if hasattr(self, "comm"):
# ~ mpi.decref(self.comm)

@utils.cached_property
def _kernel_args_(self):
Expand Down Expand Up @@ -248,15 +249,12 @@ def __init__(self, dim, data=None, dtype=None, name=None, comm=None):
if comm is None:
warnings.warn("PyOP2.Global has no comm, this is likely to break in parallel!")
self.comm = mpi.internal_comm(comm)
weakref.finalize(self, mpi.decref, self.comm)

# Object versioning setup
petsc_counter = (comm and self.dtype == PETSc.ScalarType)
VecAccessMixin.__init__(self, petsc_counter=petsc_counter)

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)

def __str__(self):
return "OP2 Global Argument: %s with dim %s and value %s" \
% (self._name, self._dim, self._data)
Expand Down
9 changes: 5 additions & 4 deletions pyop2/types/map.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
import functools
import numbers
import weakref

import numpy as np

Expand Down Expand Up @@ -37,6 +38,7 @@ def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, o
self._iterset = iterset
self._toset = toset
self.comm = mpi.internal_comm(toset.comm)
weakref.finalize(self, mpi.decref, self.comm)
self._arity = arity
self._values = utils.verify_reshape(values, dtypes.IntType,
(iterset.total_size, arity), allow_none=True)
Expand All @@ -53,10 +55,6 @@ def __init__(self, iterset, toset, arity, values=None, name=None, offset=None, o
# A cache for objects built on top of this map
self._cache = {}

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)

@utils.cached_property
def _kernel_args_(self):
return (self._values.ctypes.data, )
Expand Down Expand Up @@ -201,6 +199,7 @@ def __init__(self, map_, permutation):
raise NotImplementedError("PermutedMap of ComposedMap not implemented: simply permute before composing")
self.map_ = map_
self.comm = mpi.internal_comm(map_.comm)
weakref.finalize(self, mpi.decref, self.comm)
self.permutation = np.asarray(permutation, dtype=Map.dtype)
assert (np.unique(permutation) == np.arange(map_.arity, dtype=Map.dtype)).all()

Expand Down Expand Up @@ -252,6 +251,7 @@ def __init__(self, *maps_, name=None):
self._iterset = maps_[-1].iterset
self._toset = maps_[0].toset
self.comm = mpi.internal_comm(self._toset.comm)
weakref.finalize(self, mpi.decref, self.comm)
self._arity = maps_[0].arity
# Don't call super().__init__() to avoid calling verify_reshape()
self._values = None
Expand Down Expand Up @@ -316,6 +316,7 @@ def __init__(self, maps):
if len(comms) == 0:
raise ex.MapTypeError("Don't know how to make communicator")
self.comm = mpi.internal_comm(comms[0])
weakref.finalize(self, mpi.decref, self.comm)
self._initialized = True

@classmethod
Expand Down
29 changes: 13 additions & 16 deletions pyop2/types/mat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import ctypes
import itertools
import weakref

import numpy as np
from petsc4py import PETSc
Expand Down Expand Up @@ -69,10 +70,14 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None,
self._nrows = None if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].toset.size
self._ncols = None if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].toset.size
self.lcomm = mpi.internal_comm(dsets[0].comm if isinstance(dsets[0], GlobalDataSet) else self._rmaps[0].comm)
weakref.finalize(self, mpi.decref, self.lcomm)
self.rcomm = mpi.internal_comm(dsets[1].comm if isinstance(dsets[1], GlobalDataSet) else self._cmaps[0].comm)
weakref.finalize(self, mpi.decref, self.rcomm)
else:
self.lcomm = mpi.internal_comm(self._rmaps[0].comm)
weakref.finalize(self, mpi.decref, self.lcomm)
self.rcomm = mpi.internal_comm(self._cmaps[0].comm)
weakref.finalize(self, mpi.decref, self.rcomm)

rset, cset = self.dsets
# All rmaps and cmaps have the same data set - just use the first.
Expand All @@ -94,6 +99,7 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None,
if self.lcomm != self.rcomm:
raise ValueError("Haven't thought hard enough about different left and right communicators")
self.comm = mpi.internal_comm(self.lcomm)
weakref.finalize(self, mpi.decref, self.comm)
self._name = name or "sparsity_#x%x" % id(self)
self.iteration_regions = iteration_regions
# If the Sparsity is defined on MixedDataSets, we need to build each
Expand Down Expand Up @@ -129,14 +135,6 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None,
self._blocks = [[self]]
self._initialized = True

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)
if hasattr(self, "lcomm"):
mpi.decref(self.lcomm)
if hasattr(self, "rcomm"):
mpi.decref(self.rcomm)

_cache = {}

@classmethod
Expand Down Expand Up @@ -384,9 +382,12 @@ def __init__(self, parent, i, j):
self._blocks = [[self]]
self.iteration_regions = parent.iteration_regions
self.lcomm = mpi.internal_comm(self.dsets[0].comm)
weakref.finalize(self, mpi.decref, self.lcomm)
self.rcomm = mpi.internal_comm(self.dsets[1].comm)
weakref.finalize(self, mpi.decref, self.rcomm)
# TODO: think about lcomm != rcomm
self.comm = mpi.internal_comm(self.lcomm)
weakref.finalize(self, mpi.decref, self.comm)
self._initialized = True

@classmethod
Expand Down Expand Up @@ -446,21 +447,16 @@ class AbstractMat(DataCarrier, abc.ABC):
def __init__(self, sparsity, dtype=None, name=None):
self._sparsity = sparsity
self.lcomm = mpi.internal_comm(sparsity.lcomm)
weakref.finalize(self, mpi.decref, self.lcomm)
self.rcomm = mpi.internal_comm(sparsity.rcomm)
weakref.finalize(self, mpi.decref, self.rcomm)
self.comm = mpi.internal_comm(sparsity.comm)
weakref.finalize(self, mpi.decref, self.comm)
dtype = dtype or dtypes.ScalarType
self._datatype = np.dtype(dtype)
self._name = name or "mat_#x%x" % id(self)
self.assembly_state = Mat.ASSEMBLED

def __del__(self):
if hasattr(self, "comm"):
mpi.decref(self.comm)
if hasattr(self, "lcomm"):
mpi.decref(self.lcomm)
if hasattr(self, "rcomm"):
mpi.decref(self.rcomm)

@utils.validate_in(('access', _modes, ex.ModeValueError))
def __call__(self, access, path, lgmaps=None, unroll_map=False):
from pyop2.parloop import MatLegacyArg, MixedMatLegacyArg
Expand Down Expand Up @@ -959,6 +955,7 @@ def __init__(self, parent, i, j):
self.handle = parent.handle.getLocalSubMatrix(isrow=rowis,
iscol=colis)
self.comm = mpi.internal_comm(parent.comm)
weakref.finalize(self, mpi.decref, self.comm)
self.local_to_global_maps = self.handle.getLGMap()

@property
Expand Down
16 changes: 6 additions & 10 deletions pyop2/types/set.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ctypes
import functools
import numbers
import weakref

import numpy as np

Expand Down Expand Up @@ -66,6 +67,7 @@ def _wrapper_cache_key_(self):
('name', str, ex.NameTypeError))
def __init__(self, size, name=None, halo=None, comm=None):
self.comm = mpi.internal_comm(comm)
weakref.finalize(self, mpi.decref, self.comm)
if isinstance(size, numbers.Integral):
size = [size] * 3
size = utils.as_tuple(size, numbers.Integral, 3)
Expand All @@ -78,12 +80,6 @@ def __init__(self, size, name=None, halo=None, comm=None):
# A cache of objects built on top of this set
self._cache = {}

def __del__(self):
# Cannot use hasattr here, since child classes define `__getattr__`
# This causes infinite recursion when looked up!
if "comm" in self.__dict__:
mpi.decref(self.comm)

@utils.cached_property
def core_size(self):
"""Core set size. Owned elements not touching halo elements."""
Expand Down Expand Up @@ -234,6 +230,7 @@ class GlobalSet(Set):

def __init__(self, comm=None):
self.comm = mpi.internal_comm(comm)
weakref.finalize(self, mpi.decref, self.comm)
self._cache = {}

@utils.cached_property
Expand Down Expand Up @@ -319,6 +316,7 @@ class ExtrudedSet(Set):
def __init__(self, parent, layers, extruded_periodic=False):
self._parent = parent
self.comm = mpi.internal_comm(parent.comm)
weakref.finalize(self, mpi.decref, self.comm)
try:
layers = utils.verify_reshape(layers, dtypes.IntType, (parent.total_size, 2))
self.constant_layers = False
Expand Down Expand Up @@ -400,6 +398,7 @@ class Subset(ExtrudedSet):
('indices', (list, tuple, np.ndarray), TypeError))
def __init__(self, superset, indices):
self.comm = mpi.internal_comm(superset.comm)
weakref.finalize(self, mpi.decref, self.comm)

# sort and remove duplicates
indices = np.unique(indices)
Expand Down Expand Up @@ -544,12 +543,9 @@ def __init__(self, sets):
"All components of a MixedSet must have the same number of layers."
# TODO: do all sets need the same communicator?
self.comm = mpi.internal_comm(functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.comm, sets)))
weakref.finalize(self, mpi.decref, self.comm)
self._initialized = True

def __del__(self):
if self._initialized and hasattr(self, "comm"):
mpi.decref(self.comm)

@utils.cached_property
def _kernel_args_(self):
raise NotImplementedError
Expand Down

0 comments on commit fee60d2

Please sign in to comment.