Skip to content

Commit

Permalink
Merge pull request #2582 from firedrakeproject/ksagiyam/periodic_extr…
Browse files Browse the repository at this point in the history
…usion

Ksagiyam/periodic extrusion
  • Loading branch information
dham authored Mar 8, 2023
2 parents 19d1842 + a8c1c7a commit 4779660
Show file tree
Hide file tree
Showing 11 changed files with 608 additions and 133 deletions.
12 changes: 11 additions & 1 deletion firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,13 +792,15 @@ def build(self):
"interior_facet_horiz": op2.ON_INTERIOR_FACETS}
iteration_region = iteration_regions.get(self._integral_type, None)
extruded = self._mesh.extruded
extruded_periodic = self._mesh.extruded_periodic
constant_layers = extruded and not self._mesh.variable_layers

return op2.GlobalKernel(self._kinfo.kernel,
kernel_args,
iteration_region=iteration_region,
pass_layer_arg=self._kinfo.pass_layer_arg,
extruded=extruded,
extruded_periodic=extruded_periodic,
constant_layers=constant_layers,
subset=self._needs_subset)

Expand Down Expand Up @@ -860,8 +862,16 @@ def _get_map_arg(self, finat_element):
offset += offset
else:
offset = None
if self._mesh.extruded_periodic:
offset_quotient = eutils.calculate_dof_offset_quotient(finat_element)
if offset_quotient is not None:
offset_quotient = tuple(offset_quotient)
if self._integral_type in {"interior_facet", "interior_facet_vert"}:
offset_quotient += offset_quotient
else:
offset_quotient = None

map_arg = op2.MapKernelArg(arity, offset)
map_arg = op2.MapKernelArg(arity, offset, offset_quotient)
self._map_arg_cache[key] = map_arg
return map_arg

Expand Down
4 changes: 3 additions & 1 deletion firedrake/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,7 @@ def save_mesh(self, mesh, distribution_name=None, permutation_name=None):
path = self._path_to_topology_extruded(tmesh.name)
self.require_group(path)
self.set_attr(path, PREFIX_EXTRUDED + "_base_mesh", base_tmesh.name)
self.set_attr(path, PREFIX_EXTRUDED + "_periodic", tmesh.extruded_periodic)
self.set_attr(path, PREFIX_EXTRUDED + "_variable_layers", tmesh.variable_layers)
if tmesh.variable_layers:
# Save tmesh.layers, which contains (start layer, stop layer)-tuple for each cell
Expand Down Expand Up @@ -874,6 +875,7 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
if tmesh_key in self._tmesh_cache:
tmesh = self._tmesh_cache[tmesh_key]
else:
periodic = self.get_attr(path, PREFIX_EXTRUDED + "_periodic") if self.has_attr(path, PREFIX_EXTRUDED + "_periodic") else False
variable_layers = self.get_attr(path, PREFIX_EXTRUDED + "_variable_layers")
if variable_layers:
cell = base_tmesh.ufl_cell()
Expand All @@ -898,7 +900,7 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
lsf.bcastEnd(unit, layers_a, layers, MPI.REPLACE)
else:
layers = self.get_attr(path, PREFIX_EXTRUDED + "_layers")
tmesh = ExtrudedMeshTopology(base_tmesh, layers, name=tmesh_name)
tmesh = ExtrudedMeshTopology(base_tmesh, layers, periodic=periodic, name=tmesh_name)
self._tmesh_cache[tmesh_key] = tmesh
# -- Load mesh --
mesh_key = self._generate_mesh_key_from_names(name,
Expand Down
34 changes: 26 additions & 8 deletions firedrake/cython/dmcommon.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,7 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1):

variable = mesh.variable_layers
extruded = mesh.cell_set._extruded
extruded_periodic = mesh.cell_set._extruded_periodic
on_base_ = on_base
nodes_per_entity = np.asarray(nodes_per_entity, dtype=IntType)
if variable:
Expand All @@ -1176,7 +1177,10 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1):
if on_base:
nodes_per_entity = sum(nodes_per_entity[:, i] for i in range(2))
else:
nodes_per_entity = sum(nodes_per_entity[:, i]*(mesh.layers - i) for i in range(2))
if extruded_periodic:
nodes_per_entity = sum(nodes_per_entity[:, i]*(mesh.layers - 1) for i in range(2))
else:
nodes_per_entity = sum(nodes_per_entity[:, i]*(mesh.layers - i) for i in range(2))

section = PETSc.Section().create(comm=mesh._comm)

Expand Down Expand Up @@ -1242,7 +1246,7 @@ def get_cell_nodes(mesh,
np.ndarray[PetscInt, ndim=2, mode="c"] layer_extents
np.ndarray[PetscInt, ndim=2, mode="c"] cell_closures
np.ndarray[PetscInt, ndim=2, mode="c"] entity_orientations
bint is_swarm, variable
bint is_swarm, variable, extruded_periodic_1_layer

dm = mesh.topology_dm
is_swarm = type(dm) is PETSc.DMSwarm
Expand All @@ -1255,6 +1259,10 @@ def get_cell_nodes(mesh,
layer_extents = mesh.layer_extents
if offset is None:
raise ValueError("Offset cannot be None with variable layer extents")
# Special case: DoFs on the top layer are identified as those on the bottom layer.
extruded_periodic_1_layer = isinstance(mesh, firedrake.mesh.ExtrudedMeshTopology) and \
mesh.extruded_periodic and \
mesh.layers == 1 + 1
nclosure = cell_closures.shape[1]
# Extract ordering from FInAT element entity DoFs
ndofs_list = []
Expand Down Expand Up @@ -1315,15 +1323,25 @@ def get_cell_nodes(mesh,
if variable:
off += offset[flat_index[k]]*(layer_extents[c, 0] - layer_extents[entity, 0])
if entity_permutations is not None:
for j in range(ceil_ndofs[i]):
cell_nodes[cell, flat_index[k]] = off + entity_permutations_c[perm_offset + ceil_ndofs[i] * orient + j]
k += 1
if extruded_periodic_1_layer:
for j in range(ceil_ndofs[i]):
cell_nodes[cell, flat_index[k]] = off + entity_permutations_c[perm_offset + ceil_ndofs[i] * orient + j] % offset[flat_index[k]]
k += 1
else:
for j in range(ceil_ndofs[i]):
cell_nodes[cell, flat_index[k]] = off + entity_permutations_c[perm_offset + ceil_ndofs[i] * orient + j]
k += 1
perm_offset += ceil_ndofs[i] * num_orientations_c[i]
else:
# FInAT element must eventually add entity_permutations() method
for j in range(ceil_ndofs[i]):
cell_nodes[cell, flat_index[k]] = off + j
k += 1
if extruded_periodic_1_layer:
for j in range(ceil_ndofs[i]):
cell_nodes[cell, flat_index[k]] = off + j % offset[flat_index[k]]
k += 1
else:
for j in range(ceil_ndofs[i]):
cell_nodes[cell, flat_index[k]] = off + j
k += 1
CHKERR(PetscFree(ceil_ndofs))
CHKERR(PetscFree(flat_index))
if entity_permutations is not None:
Expand Down
52 changes: 52 additions & 0 deletions firedrake/extrusion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,58 @@ def calculate_dof_offset(finat_element):
return dof_offset


@functools.lru_cache()
def calculate_dof_offset_quotient(finat_element):
"""Return the offset quotient for each DoF within the base cell.
:arg finat_element: A FInAT element.
:returns: A numpy array containing the offset quotient for each DoF.
offset_quotient q of each DoF (in a local cell) is defined as
i // o, where i is the local DoF ID of the DoF on the entity and
o is the offset of that DoF computed in ``calculate_dof_offset()``.
Let DOF(e, l, i) represent a DoF on (base-)entity e on layer l that has local ID i
and suppose this DoF has offset o and offset_quotient q. In periodic extrusion it
is convenient to identify DOF(e, l, i) as DOF(e, l + q, i % o); this transformation
allows one to always work with the "unit cell" in which i < o always holds.
In FEA offset_quotient is 0 or 1.
Example::
local ID offset offset_quotient
2--2--2 2--2--2 1--1--1
| | | | | |
CG2 1 1 1 2 2 2 0 0 0
| | | | | |
0--0--0 2--2--2 0--0--0
+-----+ +-----+ +-----+
| 1 3 | | 4 4 | | 0 0 |
DG1 | | | | | |
| 0 2 | | 4 4 | | 0 0 |
+-----+ +-----+ +-----+
"""
# scalar-valued elements only
if isinstance(finat_element, finat.TensorFiniteElement):
finat_element = finat_element.base_element
if is_real_tensor_product_element(finat_element):
return None
dof_offset_quotient = numpy.zeros(finat_element.space_dimension(), dtype=IntType)
for (b, v), entities in finat_element.entity_dofs().items():
for entity, dof_indices in entities.items():
quotient = 1 if v == 0 and entity % 2 == 1 else 0
for i in dof_indices:
dof_offset_quotient[i] = quotient
if (dof_offset_quotient == 0).all():
# Avoid unnecessary codegen in pyop2/codegen/builder.
dof_offset_quotient = None
return dof_offset_quotient


def is_real_tensor_product_element(element):
"""Is the provided FInAT element a tensor product involving the real space?
Expand Down
17 changes: 12 additions & 5 deletions firedrake/functionspacedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ def get_top_bottom_boundary_nodes(mesh, key, V):
offset,
sub_domain)
else:
if mesh.extruded_periodic and sub_domain == "top":
raise ValueError("Invalid subdomain 'top': 'top' boundary is identified as 'bottom' boundary in periodic extrusion")
idx = {"bottom": -2, "top": -1}[sub_domain]
section, indices, facet_points = V.cell_boundary_masks
facet = facet_points[idx]
Expand Down Expand Up @@ -397,7 +399,7 @@ class FunctionSpaceData(object):
"""
__slots__ = ("real_tensorproduct", "map_cache", "entity_node_lists",
"node_set", "cell_boundary_masks",
"interior_facet_boundary_masks", "offset",
"interior_facet_boundary_masks", "offset", "offset_quotient",
"extruded", "mesh", "global_numbering")

@PETSc.Log.EventDecorator()
Expand Down Expand Up @@ -436,6 +438,10 @@ def __init__(self, mesh, ufl_element):
self.offset = eutils.calculate_dof_offset(finat_element)
else:
self.offset = None
if isinstance(mesh, mesh_mod.ExtrudedMeshTopology) and mesh.extruded_periodic:
self.offset_quotient = eutils.calculate_dof_offset_quotient(finat_element)
else:
self.offset_quotient = None

self.entity_node_lists = get_entity_node_lists(mesh, (edofs_key, real_tensorproduct, eperm_key), entity_dofs, entity_permutations, global_numbering, self.offset)
self.node_set = node_set
Expand Down Expand Up @@ -478,26 +484,27 @@ def boundary_nodes(self, V, sub_domain):
return get_facet_closure_nodes(V.mesh(), key, V)

@PETSc.Log.EventDecorator()
def get_map(self, V, entity_set, map_arity, name, offset):
def get_map(self, V, entity_set, map_arity, name, offset, offset_quotient):
"""Return a :class:`pyop2.Map` from some topological entity to
degrees of freedom.
:arg V: The :class:`FunctionSpace` to create the map for.
:arg entity_set: The :class:`pyop2.Set` of entities to map from.
:arg map_arity: The arity of the resulting map.
:arg name: A name for the resulting map.
:arg offset: Map offset (for extruded)."""
:arg offset: Map offset (for extruded).
:arg offset_quotient: Map offset_quotient (for extruded)."""
# V is only really used for error checking and "name".
assert len(V) == 1, "get_map should not be called on MixedFunctionSpace"
entity_node_list = self.entity_node_lists[entity_set]

val = self.map_cache[entity_set]
if val is None:
val = op2.Map(entity_set, self.node_set,
map_arity,
entity_node_list,
("%s_"+name) % (V.name),
offset=offset)
offset=offset,
offset_quotient=offset_quotient)

self.map_cache[entity_set] = val
return val
Expand Down
13 changes: 10 additions & 3 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ def __init__(self, mesh, element, name=None):
self.real_tensorproduct = sdata.real_tensorproduct
self.extruded = sdata.extruded
self.offset = sdata.offset
self.offset_quotient = sdata.offset_quotient
self.cell_boundary_masks = sdata.cell_boundary_masks
self.interior_facet_boundary_masks = sdata.interior_facet_boundary_masks

Expand Down Expand Up @@ -564,7 +565,8 @@ def cell_node_map(self):
self.mesh().cell_set,
self.finat_element.space_dimension(),
"cell_node",
self.offset)
self.offset,
self.offset_quotient)

def interior_facet_node_map(self):
r"""Return the :class:`pyop2.types.map.Map` from interior facets to
Expand All @@ -573,11 +575,15 @@ def interior_facet_node_map(self):
offset = self.cell_node_map().offset
if offset is not None:
offset = numpy.append(offset, offset)
offset_quotient = self.cell_node_map().offset_quotient
if offset_quotient is not None:
offset_quotient = numpy.append(offset_quotient, offset_quotient)
return sdata.get_map(self,
self.mesh().interior_facets.set,
2*self.finat_element.space_dimension(),
"interior_facet_node",
offset)
offset,
offset_quotient)

def exterior_facet_node_map(self):
r"""Return the :class:`pyop2.types.map.Map` from exterior facets to
Expand All @@ -587,7 +593,8 @@ def exterior_facet_node_map(self):
self.mesh().exterior_facets.set,
self.finat_element.space_dimension(),
"exterior_facet_node",
self.offset)
self.offset,
self.offset_quotient)

def boundary_nodes(self, sub_domain):
r"""Return the boundary nodes for this :class:`~.FunctionSpace`.
Expand Down
Loading

0 comments on commit 4779660

Please sign in to comment.