Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add disk caching for interpolation kernels #2348

Merged
merged 3 commits into from
Apr 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 50 additions & 15 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import numpy
from functools import partial, singledispatch
import os
import tempfile

import FIAT
import ufl
from ufl.algorithms import extract_arguments
from ufl.algorithms import extract_arguments, extract_coefficients
from ufl.algorithms.signature import compute_expression_signature

from pyop2 import op2
from pyop2.caching import disk_cached

from tsfc.finatinterface import create_element, as_fiat_cell
from tsfc import compile_expression_dual_evaluation
Expand All @@ -14,7 +18,7 @@
import finat

import firedrake
from firedrake import utils
from firedrake import tsfc_interface, utils
from firedrake.adjoint import annotate_interpolate
from firedrake.petsc import PETSc

Expand Down Expand Up @@ -270,6 +274,11 @@ def _interpolator(V, tensor, expr, subset, arguments, access):
rt_var_name = 'rt_X'
to_element = rebuild(to_element, expr, rt_var_name)

cell_set = target_mesh.cell_set
if subset is not None:
assert subset.superset == cell_set
cell_set = subset

parameters = {}
parameters['scalar_type'] = utils.ScalarType

Expand All @@ -279,27 +288,21 @@ def _interpolator(V, tensor, expr, subset, arguments, access):
# FIXME: for the runtime unknown point set (for cross-mesh
# interpolation) we have to pass the finat element we construct
# here. Ideally we would only pass the UFL element through.
kernel = compile_expression_dual_evaluation(expr, to_element,
V.ufl_element(),
domain=source_mesh,
parameters=parameters)
kernel = compile_expression(cell_set.comm, expr, to_element, V.ufl_element(),
domain=source_mesh, parameters=parameters)
ast = kernel.ast
oriented = kernel.oriented
needs_cell_sizes = kernel.needs_cell_sizes
coefficients = kernel.coefficients
first_coeff_fake_coords = kernel.first_coefficient_fake_coords
coefficient_numbers = kernel.coefficient_numbers
needs_external_coords = kernel.needs_external_coords
name = kernel.name
kernel = op2.Kernel(ast, name, requires_zeroed_output_arguments=True,
flop_count=kernel.flop_count)
cell_set = target_mesh.cell_set
if subset is not None:
assert subset.superset == cell_set
cell_set = subset
parloop_args = [kernel, cell_set]

if first_coeff_fake_coords:
# Replace with real source mesh coordinates
coefficients[0] = source_mesh.coordinates
coefficients = tsfc_interface.extract_numbered_coefficients(expr, coefficient_numbers)
if needs_external_coords:
coefficients = [source_mesh.coordinates] + coefficients

if target_mesh is not source_mesh:
# NOTE: TSFC will sometimes drop run-time arguments in generated
Expand Down Expand Up @@ -381,6 +384,27 @@ def _interpolator(V, tensor, expr, subset, arguments, access):
return copyin + (parloop_compute_callable, ) + copyout


try:
_expr_cachedir = os.environ["FIREDRAKE_TSFC_KERNEL_CACHE_DIR"]
except KeyError:
_expr_cachedir = os.path.join(tempfile.gettempdir(),
f"firedrake-tsfc-expression-kernel-cache-uid{os.getuid()}")


def _compile_expression_key(comm, expr, to_element, ufl_element, domain, parameters):
"""Generate a cache key suitable for :func:`tsfc.compile_expression_dual_evaluation`."""
# Since the caching is collective, this function must return a 2-tuple of
# the form (comm, key) where comm is the communicator the cache is collective over.
# FIXME FInAT elements are not safely hashable so we ignore them here
key = _hash_expr(expr), hash(ufl_element), utils.tuplify(parameters)
return comm, key


@disk_cached({}, _expr_cachedir, key=_compile_expression_key, collective=True)
def compile_expression(comm, *args, **kwargs):
return compile_expression_dual_evaluation(*args, **kwargs)


@singledispatch
def rebuild(element, expr, rt_var_name):
raise NotImplementedError(f"Cross mesh interpolation not implemented for a {element} element.")
Expand Down Expand Up @@ -490,3 +514,14 @@ def __init__(self, glob):
self.dat = glob
self.cell_node_map = lambda *arguments: None
self.ufl_domain = lambda: None


def _hash_expr(expr):
"""Return a numbering-invariant hash of a UFL expression.

:arg expr: A UFL expression.
:returns: A numbering-invariant hash for the expression.
"""
domain_numbering = {d: i for i, d in enumerate(ufl.domain.extract_domains(expr))}
coefficient_numbering = {c: i for i, c in enumerate(extract_coefficients(expr))}
return compute_expression_signature(expr, {**domain_numbering, **coefficient_numbering})
8 changes: 4 additions & 4 deletions firedrake/preconditioners/pmg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
add_hook, get_parent, push_parent, pop_parent,
get_function_space, set_function_space)
from firedrake.solving_utils import _SNESContext
from firedrake.tsfc_interface import extract_numbered_coefficients
from firedrake.utils import ScalarType_c, IntType_c
from firedrake.petsc import PETSc
import firedrake
Expand Down Expand Up @@ -500,10 +501,9 @@ def prolongation_transfer_kernel_action(Vf, expr):
from tsfc.finatinterface import create_element
to_element = create_element(Vf.ufl_element())
kernel = compile_expression_dual_evaluation(expr, to_element, Vf.ufl_element())
coefficients = kernel.coefficients
if kernel.first_coefficient_fake_coords:
target_mesh = Vf.ufl_domain()
coefficients[0] = target_mesh.coordinates
coefficients = extract_numbered_coefficients(expr, kernel.coefficient_numbers)
if kernel.needs_external_coords:
coefficients = [Vf.ufl_domain().coordinates] + coefficients

return op2.Kernel(kernel.ast, kernel.name,
requires_zeroed_output_arguments=True,
Expand Down
18 changes: 18 additions & 0 deletions firedrake/tsfc_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,3 +304,21 @@ def as_pyop2_local_kernel(ast, name, nargs, access=op2.INC, **kwargs):
accesses = tuple([access] + [op2.READ]*(nargs-1))
return op2.Kernel(ast, name, accesses=accesses,
requires_zeroed_output_arguments=True, **kwargs)


def extract_numbered_coefficients(expr, numbers):
"""Return expression coefficients specified by a numbering.

:arg expr: A UFL expression.
:arg numbers: Iterable of indices used for selecting the correct coefficients
from ``expr``.
:returns: A list of UFL coefficients.
"""
orig_coefficients = ufl.algorithms.extract_coefficients(expr)
coefficients = []
for coeff in (orig_coefficients[i] for i in numbers):
if type(coeff.ufl_element()) == ufl.MixedElement:
coefficients.extend(coeff.split())
else:
coefficients.append(coeff)
return coefficients