Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove class-level caches from CheckpointFile #2810

Merged
merged 2 commits into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 105 additions & 139 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import functools
import pickle
import weakref
from petsc4py.PETSc import ViewerHDF5
import ufl
from pyop2 import op2
Expand Down Expand Up @@ -517,10 +516,6 @@ class CheckpointFile(object):
One can also use different number of processes for saving and for loading.

"""
# Cache for loaded meshes.
_mesh_cache = weakref.WeakValueDictionary()
_tmesh_cache = weakref.WeakValueDictionary()

def __init__(self, filename, mode, comm=COMM_WORLD):
self.viewer = ViewerHDF5()
self.filename = filename
Expand Down Expand Up @@ -869,88 +864,67 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
base_tmesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh")
base_tmesh = self._load_mesh_topology(base_tmesh_name, reorder, distribution_parameters)
base_tmesh.init()
tmesh_key = self._generate_mesh_key_from_names(tmesh_name,
base_tmesh._distribution_name,
base_tmesh._permutation_name)
if tmesh_key in self._tmesh_cache:
tmesh = self._tmesh_cache[tmesh_key]
periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False
variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers")
if variable_layers:
cell = base_tmesh.ufl_cell()
element = ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2)
_ = self._load_function_space_topology(base_tmesh, element)
base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name,
base_tmesh._distribution_name,
base_tmesh._permutation_name)
sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element)
_, _, lsf = self._function_load_utils[base_tmesh_key + sd_key]
nroots, _, _ = lsf.getGraph()
layers_a = np.empty(nroots, dtype=utils.IntType)
layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm)
layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"]))
self.viewer.pushGroup(path)
layers_a_iset.load(self.viewer)
self.viewer.popGroup()
layers_a = layers_a_iset.getIndices()
layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType)
unit = MPI._typedict[np.dtype(utils.IntType).char]
lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE)
lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE)
else:
periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False
variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers")
if variable_layers:
cell = base_tmesh.ufl_cell()
element = ufl.VectorElement("DP" if cell.is_simplex() else "DQ", cell, 0, dim=2)
_ = self._load_function_space_topology(base_tmesh, element)
base_tmesh_key = self._generate_mesh_key_from_names(base_tmesh.name,
base_tmesh._distribution_name,
base_tmesh._permutation_name)
sd_key = self._get_shared_data_key_for_checkpointing(base_tmesh, element)
_, _, lsf = self._function_load_utils[base_tmesh_key + sd_key]
nroots, _, _ = lsf.getGraph()
layers_a = np.empty(nroots, dtype=utils.IntType)
layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm)
layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"]))
self.viewer.pushGroup(path)
layers_a_iset.load(self.viewer)
self.viewer.popGroup()
layers_a = layers_a_iset.getIndices()
layers = np.empty((base_tmesh.cell_set.total_size, 2), dtype=utils.IntType)
unit = MPI._typedict[np.dtype(utils.IntType).char]
lsf.bcastBegin(unit, layers_a, layers, MPI.REPLACE)
lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE)
else:
layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers")
tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name)
self._tmesh_cache[tmesh_key] = tmesh
layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers")
tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name)
# -- Load mesh --
mesh_key = self._generate_mesh_key_from_names(name,
base_tmesh._distribution_name,
base_tmesh._permutation_name)
if mesh_key in self._mesh_cache:
mesh = self._mesh_cache[mesh_key]
else:
path = self._path_to_mesh(tmesh_name, name)
coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element"))
coord_name = self.get_attr(path, PREFIX + "_coordinates")
coordinates = self._load_function_topology(tmesh, coord_element, coord_name)
mesh = make_mesh_from_coordinates(coordinates, name)
if self.has_attr(path, PREFIX + "_radial_coordinates"):
radial_coord_element = self._unpickle(self.get_attr(path, PREFIX + "_radial_coordinate_element"))
radial_coord_name = self.get_attr(path, PREFIX + "_radial_coordinates")
radial_coordinates = self._load_function_topology(tmesh, radial_coord_element, radial_coord_name)
tV_radial_coord = impl.FunctionSpace(tmesh, radial_coord_element)
V_radial_coord = impl.WithGeometry.create(tV_radial_coord, mesh)
radial_coord_function_name = self.get_attr(path, PREFIX + "_radial_coordinate_function")
mesh.radial_coordinates = Function(V_radial_coord, val=radial_coordinates, name=radial_coord_function_name)
# The followings are conceptually redundant, but needed.
path = os.path.join(self._path_to_mesh(tmesh_name, name), PREFIX_EXTRUDED)
base_mesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh")
mesh._base_mesh = self.load_mesh(base_mesh_name)
self._mesh_cache[mesh_key] = mesh
path = self._path_to_mesh(tmesh_name, name)
coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element"))
coord_name = self.get_attr(path, PREFIX + "_coordinates")
coordinates = self._load_function_topology(tmesh, coord_element, coord_name)
mesh = make_mesh_from_coordinates(coordinates, name)
if self.has_attr(path, PREFIX + "_radial_coordinates"):
radial_coord_element = self._unpickle(self.get_attr(path, PREFIX + "_radial_coordinate_element"))
radial_coord_name = self.get_attr(path, PREFIX + "_radial_coordinates")
radial_coordinates = self._load_function_topology(tmesh, radial_coord_element, radial_coord_name)
tV_radial_coord = impl.FunctionSpace(tmesh, radial_coord_element)
V_radial_coord = impl.WithGeometry.create(tV_radial_coord, mesh)
radial_coord_function_name = self.get_attr(path, PREFIX + "_radial_coordinate_function")
mesh.radial_coordinates = Function(V_radial_coord, val=radial_coordinates, name=radial_coord_function_name)
# The followings are conceptually redundant, but needed.
path = os.path.join(self._path_to_mesh(tmesh_name, name), PREFIX_EXTRUDED)
base_mesh_name = self.get_attr(path, PREFIX_EXTRUDED + "_base_mesh")
mesh._base_mesh = self.load_mesh(base_mesh_name)
else:
utils._init()
# -- Load mesh topology --
tmesh = self._load_mesh_topology(tmesh_name, reorder, distribution_parameters)
mesh_key = self._generate_mesh_key_from_names(name,
tmesh._distribution_name,
tmesh._permutation_name)
if mesh_key in self._mesh_cache:
mesh = self._mesh_cache[mesh_key]
else:
# -- Load coordinates --
# tmesh.topology_dm has already been redistributed.
path = self._path_to_mesh(tmesh_name, name)
# Load firedrake coordinates directly.
# When implementing checkpointing for MeshHierarchy in the future,
# we will need to postpone calling tmesh.init().
tmesh.init()
coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element"))
coord_name = self.get_attr(path, PREFIX + "_coordinates")
coordinates = self._load_function_topology(tmesh, coord_element, coord_name)
mesh = make_mesh_from_coordinates(coordinates, name)
# Load plex coordinates for a complete representation of plex.
tmesh.topology_dm.coordinatesLoad(self.viewer, tmesh.sfXC)
self._mesh_cache[mesh_key] = mesh
# -- Load coordinates --
# tmesh.topology_dm has already been redistributed.
path = self._path_to_mesh(tmesh_name, name)
# Load firedrake coordinates directly.
# When implementing checkpointing for MeshHierarchy in the future,
# we will need to postpone calling tmesh.init().
tmesh.init()
coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element"))
coord_name = self.get_attr(path, PREFIX + "_coordinates")
coordinates = self._load_function_topology(tmesh, coord_element, coord_name)
mesh = make_mesh_from_coordinates(coordinates, name)
# Load plex coordinates for a complete representation of plex.
tmesh.topology_dm.coordinatesLoad(self.viewer, tmesh.sfXC)
return mesh

@PETSc.Log.EventDecorator("LoadMeshTopology")
Expand Down Expand Up @@ -989,65 +963,57 @@ def _load_mesh_topology(self, tmesh_name, reorder, distribution_parameters):
distribution_name = None
permutation_name = None
perm_is = None
# This is only to return the same tmesh object if the same set of arguments are given.
# Multiple tmesh_key might end up having the same value, but it is hard to process
# all distribution and reorder options at this stage (many things happen in MeshTopology constructor).
tmesh_key = self._generate_mesh_key(tmesh_name, distribution_name, permutation_name, reorder, distribution_parameters)
if tmesh_key in self._tmesh_cache:
tmesh = self._tmesh_cache[tmesh_key]
plex = PETSc.DMPlex()
plex.create(comm=self._comm)
plex.setName(tmesh_name)
# Check format
path = os.path.join(self._path_to_topology(tmesh_name), "topology")
if any(d not in self.h5pyfile for d in [os.path.join(path, "cells"),
os.path.join(path, "cones"),
os.path.join(path, "order"),
os.path.join(path, "orientation")]):
raise RuntimeError(f"Unsupported PETSc ViewerHDF5 format used in {self.filename}")
format = ViewerHDF5.Format.HDF5_PETSC
self.viewer.pushFormat(format=format)
plex.distributionSetName(distribution_name)
sfXB = plex.topologyLoad(self.viewer)
plex.distributionSetName(None)
self.viewer.popFormat()
if load_distribution_permutation:
chart_size = np.empty(1, dtype=utils.IntType)
chart_sizes_iset = PETSc.IS().createGeneral(chart_size, comm=self._comm)
chart_sizes_iset.setName("chart_sizes")
path = self._path_to_distribution(tmesh_name, distribution_name)
self.viewer.pushGroup(path)
chart_sizes_iset.load(self.viewer)
self.viewer.popGroup()
chart_size = chart_sizes_iset.getIndices().item()
perm = np.empty(chart_size, dtype=utils.IntType)
perm_is = PETSc.IS().createGeneral(perm, comm=self._comm)
path = self._path_to_permutation(tmesh_name, distribution_name, permutation_name)
self.viewer.pushGroup(path)
perm_is.setName("permutation")
perm_is.load(self.viewer)
perm_is.setName(None)
self.viewer.popGroup()
else:
plex = PETSc.DMPlex()
plex.create(comm=self._comm)
plex.setName(tmesh_name)
# Check format
path = os.path.join(self._path_to_topology(tmesh_name), "topology")
if any(d not in self.h5pyfile for d in [os.path.join(path, "cells"),
os.path.join(path, "cones"),
os.path.join(path, "order"),
os.path.join(path, "orientation")]):
raise RuntimeError(f"Unsupported PETSc ViewerHDF5 format used in {self.filename}")
format = ViewerHDF5.Format.HDF5_PETSC
self.viewer.pushFormat(format=format)
plex.distributionSetName(distribution_name)
sfXB = plex.topologyLoad(self.viewer)
plex.distributionSetName(None)
self.viewer.popFormat()
if load_distribution_permutation:
chart_size = np.empty(1, dtype=utils.IntType)
chart_sizes_iset = PETSc.IS().createGeneral(chart_size, comm=self._comm)
chart_sizes_iset.setName("chart_sizes")
path = self._path_to_distribution(tmesh_name, distribution_name)
self.viewer.pushGroup(path)
chart_sizes_iset.load(self.viewer)
self.viewer.popGroup()
chart_size = chart_sizes_iset.getIndices().item()
perm = np.empty(chart_size, dtype=utils.IntType)
perm_is = PETSc.IS().createGeneral(perm, comm=self._comm)
path = self._path_to_permutation(tmesh_name, distribution_name, permutation_name)
self.viewer.pushGroup(path)
perm_is.setName("permutation")
perm_is.load(self.viewer)
perm_is.setName(None)
self.viewer.popGroup()
else:
perm_is = None
# -- Construct Mesh (Topology) --
# Use public API so pass user comm (self.comm)
tmesh = MeshTopology(plex, name=plex.getName(), reorder=reorder,
distribution_parameters=distribution_parameters, sfXB=sfXB, perm_is=perm_is,
distribution_name=distribution_name, permutation_name=permutation_name,
comm=self.comm)
self.viewer.pushFormat(format=format)
# tmesh.topology_dm has already been redistributed.
sfXCtemp = tmesh.sfXB.compose(tmesh.sfBC) if tmesh.sfBC is not None else tmesh.sfXB
plex.labelsLoad(self.viewer, sfXCtemp)
self.viewer.popFormat()
# These labels are distribution dependent.
# We should be able to save/load labels selectively.
plex.removeLabel("pyop2_core")
plex.removeLabel("pyop2_owned")
plex.removeLabel("pyop2_ghost")
self._tmesh_cache[tmesh_key] = tmesh
perm_is = None
# -- Construct Mesh (Topology) --
# Use public API so pass user comm (self.comm)
tmesh = MeshTopology(plex, name=plex.getName(), reorder=reorder,
distribution_parameters=distribution_parameters, sfXB=sfXB, perm_is=perm_is,
distribution_name=distribution_name, permutation_name=permutation_name,
comm=self.comm)
self.viewer.pushFormat(format=format)
# tmesh.topology_dm has already been redistributed.
sfXCtemp = tmesh.sfXB.compose(tmesh.sfBC) if tmesh.sfBC is not None else tmesh.sfXB
plex.labelsLoad(self.viewer, sfXCtemp)
self.viewer.popFormat()
# These labels are distribution dependent.
# We should be able to save/load labels selectively.
plex.removeLabel("pyop2_core")
plex.removeLabel("pyop2_owned")
plex.removeLabel("pyop2_ghost")
return tmesh

@PETSc.Log.EventDecorator("LoadFunctionSpace")
Expand Down
Loading