Skip to content

Commit

Permalink
Merge pull request #2787 from firedrakeproject/ksagiyam/io_do_not_use…
Browse files Browse the repository at this point in the history
…_pickle

Ksagiyam/io do not use pickle
  • Loading branch information
ksagiyam authored Mar 13, 2023
2 parents aef07a7 + 29b1266 commit f528253
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 44 deletions.
6 changes: 3 additions & 3 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import itertools

import ufl
from ufl import as_ufl, UFLException, as_tensor, VectorElement
from ufl import as_ufl, as_tensor, VectorElement
import finat

import pyop2 as op2
Expand Down Expand Up @@ -347,11 +347,11 @@ def function_arg(self, g):
try:
g = as_ufl(g)
self._function_arg = g
except UFLException:
except ValueError:
try:
# Recurse to handle this through interpolation.
self.function_arg = as_ufl(as_tensor(g))
except UFLException:
except ValueError:
raise ValueError(f"{g} is not a valid DirichletBC expression")

def homogenize(self):
Expand Down
34 changes: 23 additions & 11 deletions firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,15 +590,15 @@ def save_mesh(self, mesh, distribution_name=None, permutation_name=None):
if mesh.name not in self.require_group(path):
path = self._path_to_mesh(tmesh.name, mesh.name)
self.require_group(path)
self.set_attr(path, PREFIX + "_coordinate_element", self._pickle(mesh._coordinates.function_space().ufl_element()))
self._save_ufl_element(path, PREFIX + "_coordinate_element", mesh._coordinates.function_space().ufl_element())
self.set_attr(path, PREFIX + "_coordinates", mesh._coordinates.name())
self._save_function_topology(mesh._coordinates)
if hasattr(mesh, PREFIX + "_radial_coordinates"):
# Cannot do: self.save_function(mesh.radial_coordinates)
# This will cause infinite recursion.
self.set_attr(path, PREFIX + "_radial_coordinate_function", mesh.radial_coordinates.name())
radial_coordinates = mesh.radial_coordinates.topological
self.set_attr(path, PREFIX + "_radial_coordinate_element", self._pickle(radial_coordinates.function_space().ufl_element()))
self._save_ufl_element(path, PREFIX + "_radial_coordinate_element", radial_coordinates.function_space().ufl_element())
self.set_attr(path, PREFIX + "_radial_coordinates", radial_coordinates.name())
self._save_function_topology(radial_coordinates)
self._update_mesh_name_topology_name_map({mesh.name: tmesh.name})
Expand All @@ -616,7 +616,7 @@ def save_mesh(self, mesh, distribution_name=None, permutation_name=None):
path = self._path_to_mesh(tmesh.name, mesh.name)
self.require_group(path)
# Save Firedrake coodinates.
self.set_attr(path, PREFIX + "_coordinate_element", self._pickle(mesh._coordinates.function_space().ufl_element()))
self._save_ufl_element(path, PREFIX + "_coordinate_element", mesh._coordinates.function_space().ufl_element())
self.set_attr(path, PREFIX + "_coordinates", mesh._coordinates.name())
self._save_function_topology(mesh._coordinates)
# Save DMPlex coordinates for a complete representation of the plex.
Expand Down Expand Up @@ -719,11 +719,11 @@ def _save_function_space(self, V):
# Save UFL element
path = self._path_to_function_space(tmesh.name, mesh.name, V_name)
self.require_group(path)
self.set_attr(path, PREFIX + "_ufl_element", self._pickle(element))
# Test if the pickled UFL element matches the original element
loaded_element = self._unpickle(self.get_attr(path, PREFIX + "_ufl_element"))
self._save_ufl_element(path, PREFIX + "_ufl_element", element)
# Test if the loaded UFL element matches the original element
loaded_element = self._load_ufl_element(path, PREFIX + "_ufl_element")
if loaded_element != element:
raise RuntimeError(f"pickled UFL element ({loaded_element}) does not match the original element ({element})")
raise RuntimeError(f"Loaded UFL element ({loaded_element}) does not match the original element ({element})")

@PETSc.Log.EventDecorator("SaveFunctionSpaceTopology")
def _save_function_space_topology(self, tV):
Expand Down Expand Up @@ -892,12 +892,12 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name)
# -- Load mesh --
path = self._path_to_mesh(tmesh_name, name)
coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element"))
coord_element = self._load_ufl_element(path, PREFIX + "_coordinate_element")
coord_name = self.get_attr(path, PREFIX + "_coordinates")
coordinates = self._load_function_topology(tmesh, coord_element, coord_name)
mesh = make_mesh_from_coordinates(coordinates, name)
if self.has_attr(path, PREFIX + "_radial_coordinates"):
radial_coord_element = self._unpickle(self.get_attr(path, PREFIX + "_radial_coordinate_element"))
radial_coord_element = self._load_ufl_element(path, PREFIX + "_radial_coordinate_element")
radial_coord_name = self.get_attr(path, PREFIX + "_radial_coordinates")
radial_coordinates = self._load_function_topology(tmesh, radial_coord_element, radial_coord_name)
tV_radial_coord = impl.FunctionSpace(tmesh, radial_coord_element)
Expand All @@ -919,7 +919,7 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
# When implementing checkpointing for MeshHierarchy in the future,
# we will need to postpone calling tmesh.init().
tmesh.init()
coord_element = self._unpickle(self.get_attr(path, PREFIX + "_coordinate_element"))
coord_element = self._load_ufl_element(path, PREFIX + "_coordinate_element")
coord_name = self.get_attr(path, PREFIX + "_coordinates")
coordinates = self._load_function_topology(tmesh, coord_element, coord_name)
mesh = make_mesh_from_coordinates(coordinates, name)
Expand Down Expand Up @@ -1039,7 +1039,7 @@ def _load_function_space(self, mesh, name):
elif self._is_function_space(tmesh.name, mesh.name, name):
# Load function space data
path = self._path_to_function_space(tmesh.name, mesh.name, name)
element = self._unpickle(self.get_attr(path, PREFIX + "_ufl_element"))
element = self._load_ufl_element(path, PREFIX + "_ufl_element")
tV = self._load_function_space_topology(tmesh, element)
# Construct function space
V = impl.WithGeometry.create(tV, mesh)
Expand Down Expand Up @@ -1353,6 +1353,18 @@ def _update_pickled_dict(self, name, new_item, *args):
the_dict.update(new_item)
getattr(self, "_set_" + name)(*args, the_dict)

def _save_ufl_element(self, path, name, elem):
self.set_attr(path, name + "_repr", repr(elem))

def _load_ufl_element(self, path, name):
if self.has_attr(path, name + "_repr"):
globals = {}
locals = {}
exec("from ufl import *", globals, locals)
return eval(self.get_attr(path, name + "_repr"), globals, locals)
else:
return self._unpickle(self.get_attr(path, name)) # backward compat.

def _set_mesh_name_topology_name_map(self, new_item):
path = self._path_to_topologies()
self._write_pickled_dict(path, PREFIX + "_mesh_name_topology_name_map", new_item)
Expand Down
2 changes: 1 addition & 1 deletion firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def callable(loops, f):
def _interpolator(V, tensor, expr, subset, arguments, access):
try:
expr = ufl.as_ufl(expr)
except ufl.UFLException:
except ValueError:
raise ValueError("Expecting to interpolate a UFL expression")
try:
to_element = create_element(V.ufl_element())
Expand Down
19 changes: 5 additions & 14 deletions firedrake/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import tsfc.logging # noqa: F401
import pyop2.logger # noqa: F401
import coffee.logger # noqa: F401
from ufl.log import ufl_logger

from pyop2.mpi import COMM_WORLD

Expand Down Expand Up @@ -72,10 +71,9 @@ def set_log_handlers(handlers=None, comm=COMM_WORLD):
handlers = {}

for package in packages:
if package != "UFL":
logger = logging.getLogger(package)
for handler in logger.handlers:
logger.removeHandler(handler)
logger = logging.getLogger(package)
for handler in logger.handlers:
logger.removeHandler(handler)

handler = handlers.get(package, None)
if handler is None:
Expand All @@ -85,10 +83,7 @@ def set_log_handlers(handlers=None, comm=COMM_WORLD):
if comm is not None and comm.rank != 0:
handler = logging.NullHandler()

if package == "UFL":
ufl_logger.set_handler(handler)
else:
logger.addHandler(handler)
logger.addHandler(handler)


def set_log_level(level):
Expand All @@ -101,11 +96,7 @@ def set_log_level(level):
"""
for package in packages:
if package == "UFL":
from ufl.log import ufl_logger as logger
logger = logger.get_logger()
else:
logger = logging.getLogger(package)
logger = logging.getLogger(package)
logger.setLevel(level)


Expand Down
3 changes: 2 additions & 1 deletion firedrake/slate/slac/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from firedrake.parameters import parameters
from firedrake.petsc import get_petsc_variables
from firedrake.utils import complex_mode, ScalarType_c, as_cstr
from ufl.log import GREEN
from gem.utils import groupby
from gem import impero_utils
from itertools import chain
Expand All @@ -53,6 +52,8 @@

__all__ = ['compile_expression']

GREEN = "\033[1;37;32m%s\033[0m"


try:
PETSC_DIR, PETSC_ARCH = get_petsc_dir()
Expand Down
10 changes: 4 additions & 6 deletions firedrake/ufl_expr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import ufl
import ufl.argument
from ufl.assertions import ufl_assert
from ufl.split_functions import split
from ufl.algorithms import extract_arguments, extract_coefficients

Expand Down Expand Up @@ -63,11 +62,10 @@ def reconstruct(self, function_space=None,
if number is self._number and part is self._part \
and function_space is self.function_space():
return self
ufl_assert(isinstance(number, int),
"Expecting an int, not %s" % number)
ufl_assert(function_space.ufl_element().value_shape()
== self.ufl_element().value_shape(),
"Cannot reconstruct an Argument with a different value shape.")
if not isinstance(number, int):
raise TypeError(f"Expecting an int, not {number}")
if function_space.ufl_element().value_shape() != self.ufl_element().value_shape():
raise ValueError("Cannot reconstruct an Argument with a different value shape.")
return Argument(function_space, number, part=part)


Expand Down
2 changes: 1 addition & 1 deletion tests/extrusion/test_cellvolume_extrusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_miscellaneous():
assert np.allclose(assemble(CellVolume(mesh)('+')*dS_v), sqrt(2) - 1)
assert np.allclose(assemble(CellVolume(mesh)('-')*dS_v), sqrt(2) - 1)

with pytest.raises(UFLException):
with pytest.raises(ValueError):
assemble(FacetArea(mesh)*dx)

assert np.allclose(assemble(FacetArea(mesh)*ds_b), 0.5)
Expand Down
4 changes: 2 additions & 2 deletions tests/regression/test_cellcoordinate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_cell_coordinate_ds(mesh):

def test_cell_coordinate_dS_not_restricted():
mesh = UnitSquareMesh(1, 1)
with pytest.raises(UFLException):
with pytest.raises(ValueError):
assemble(CellCoordinate(mesh)[0]*dS)


Expand All @@ -35,7 +35,7 @@ def test_cell_coordinate_dS():


def test_facet_coordinate_dx(mesh):
with pytest.raises(UFLException):
with pytest.raises(ValueError):
assemble(FacetCoordinate(mesh)[0]*dx)


Expand Down
2 changes: 1 addition & 1 deletion tests/regression/test_cellvolume.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_miscellaneous():
assemble(CellVolume(mesh)('-')*dS)]),
[1 - 1/sqrt(2), 1/sqrt(2)])

with pytest.raises(UFLException):
with pytest.raises(ValueError):
assemble(FacetArea(mesh)*dx)

assert np.allclose(assemble(FacetArea(mesh)*ds), 2*(3 - sqrt(2)))
Expand Down
7 changes: 3 additions & 4 deletions tests/regression/test_coordinatederivative.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import numpy as np
from firedrake import *
from ufl.log import UFLException


@pytest.mark.skipif(utils.complex_mode, reason="Don't expect coordinate derivatives to work in complex")
Expand Down Expand Up @@ -89,11 +88,11 @@ def test_integral_scaling_edge_case():
u = Function(V)

J = u * u * dx
with pytest.raises(UFLException):
with pytest.raises(ValueError):
assemble(Constant(2.0) * derivative(J, X))
with pytest.raises(UFLException):
with pytest.raises(ValueError):
assemble(derivative(Constant(2.0) * derivative(J, X), X))
with pytest.raises(UFLException):
with pytest.raises(ValueError):
assemble(Constant(2.0) * derivative(derivative(J, X), X))


Expand Down

0 comments on commit f528253

Please sign in to comment.