Skip to content

Commit

Permalink
add type hints to mesh partitioning
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Jun 27, 2022
1 parent a969e78 commit 1dbee5c
Showing 1 changed file with 56 additions and 37 deletions.
93 changes: 56 additions & 37 deletions meshmode/mesh/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from functools import reduce
from numbers import Real
from typing import Optional, Union
from typing import Optional, Union, Any, Tuple, Dict, List, Set

from dataclasses import dataclass

Expand All @@ -32,7 +32,10 @@
import modepy as mp

from meshmode.mesh import (
MeshElementGroup,
Mesh,
BTAG_PARTITION,
PartID,
InteriorAdjacencyGroup,
BoundaryAdjacencyGroup,
InterPartAdjacencyGroup
Expand Down Expand Up @@ -81,14 +84,17 @@ def find_group_indices(groups, meshwide_elems):
# {{{ partition_mesh

def _compute_global_elem_to_part_elem(
nelements, part_id_to_elements, part_id_to_part_index, element_id_dtype):
nelements: int,
part_id_to_elements: Dict[PartID, np.ndarray],
part_id_to_part_index: Dict[PartID, int],
element_id_dtype: Any) -> np.ndarray:
"""
Create a map from global element index to part-wide element index for a set of
parts.
:arg nelements: The number of elements in the global mesh.
:arg part_id_to_elements: A :class:`dict` mapping part identifiers to
sets of elements.
:arg part_id_to_elements: A :class:`dict` mapping a part identifier to
a sorted :class:`numpy.ndarray` of elements.
:arg part_id_to_part_index: A mapping from part identifiers to indices in
the range ``[0, num_parts)``.
:arg element_id_dtype: The element index data type.
Expand All @@ -107,7 +113,10 @@ def _compute_global_elem_to_part_elem(
return global_elem_to_part_elem


def _filter_mesh_groups(mesh, selected_elements, vertex_id_dtype):
def _filter_mesh_groups(
mesh: Mesh,
selected_elements: np.ndarray,
vertex_id_dtype: Any) -> Tuple[List, np.ndarray]:
"""
Create new mesh groups containing a selected subset of elements.
Expand Down Expand Up @@ -173,7 +182,10 @@ def _filter_mesh_groups(mesh, selected_elements, vertex_id_dtype):


def _get_connected_parts(
mesh, part_id_to_part_index, global_elem_to_part_elem, self_part_id):
mesh: Mesh,
part_id_to_part_index: Dict[PartID, int],
global_elem_to_part_elem: np.ndarray,
self_part_id: PartID) -> "Set[PartID]":
"""
Find the parts that are connected to the current part.
Expand Down Expand Up @@ -218,8 +230,12 @@ def _get_connected_parts(
if part_index in connected_part_indices}


def _create_self_to_self_adjacency_groups(mesh, global_elem_to_part_elem,
self_part_index, self_mesh_groups, self_mesh_group_elem_base):
def _create_self_to_self_adjacency_groups(
mesh: Mesh,
global_elem_to_part_elem: np.ndarray,
self_part_index: int,
self_mesh_groups: List[MeshElementGroup],
self_mesh_group_elem_base: List[int]) -> List[List[InteriorAdjacencyGroup]]:
r"""
Create self-to-self facial adjacency groups for a partitioned mesh.
Expand All @@ -229,9 +245,9 @@ def _create_self_to_self_adjacency_groups(mesh, global_elem_to_part_elem,
:func:`_compute_global_elem_to_part_elem`` for details.
:arg self_part_index: The index of the part currently being created, in the
range ``[0, num_parts)``.
:arg self_mesh_groups: An array of :class:`~meshmode.mesh.MeshElementGroup`
:arg self_mesh_groups: A list of :class:`~meshmode.mesh.MeshElementGroup`
instances representing the partitioned mesh groups.
:arg self_mesh_group_elem_base: An array containing the starting part-wide
:arg self_mesh_group_elem_base: A list containing the starting part-wide
element index for each group in *self_mesh_groups*.
:returns: A list of lists of `~meshmode.mesh.InteriorAdjacencyGroup` instances
Expand Down Expand Up @@ -283,8 +299,13 @@ def _create_self_to_self_adjacency_groups(mesh, global_elem_to_part_elem,


def _create_self_to_other_adjacency_groups(
mesh, part_id_to_part_index, global_elem_to_part_elem, self_part_id,
self_mesh_groups, self_mesh_group_elem_base, connected_parts):
mesh: Mesh,
part_id_to_part_index: Dict[PartID, int],
global_elem_to_part_elem: np.ndarray,
self_part_id: PartID,
self_mesh_groups: List[MeshElementGroup],
self_mesh_group_elem_base: List[int],
connected_parts: Set[PartID]) -> List[List[InterPartAdjacencyGroup]]:
"""
Create self-to-other adjacency groups for the partitioned mesh.
Expand All @@ -295,9 +316,9 @@ def _create_self_to_other_adjacency_groups(
indices to part indices and part-wide element indices. See
:func:`_compute_global_elem_to_part_elem`` for details.
:arg self_part_id: The identifier of the part currently being created.
:arg self_mesh_groups: An array of `~meshmode.mesh.MeshElementGroup` instances
:arg self_mesh_groups: A list of `~meshmode.mesh.MeshElementGroup` instances
representing the partitioned mesh groups.
:arg self_mesh_group_elem_base: An array containing the starting part-wide
:arg self_mesh_group_elem_base: A list containing the starting part-wide
element index for each group in *self_mesh_groups*.
:arg connected_parts: A :class:`set` containing the parts connected to
the current one.
Expand Down Expand Up @@ -358,8 +379,12 @@ def _create_self_to_other_adjacency_groups(
return self_to_other_adj_groups


def _create_boundary_groups(mesh, global_elem_to_part_elem, self_part_index,
self_mesh_groups, self_mesh_group_elem_base):
def _create_boundary_groups(
mesh: Mesh,
global_elem_to_part_elem: np.ndarray,
self_part_index: PartID,
self_mesh_groups: List[MeshElementGroup],
self_mesh_group_elem_base: List[int]) -> List[List[BoundaryAdjacencyGroup]]:
"""
Create boundary groups for partitioned mesh.
Expand All @@ -369,9 +394,9 @@ def _create_boundary_groups(mesh, global_elem_to_part_elem, self_part_index,
:func:`_compute_global_elem_to_part_elem`` for details.
:arg self_part_index: The index of the part currently being created, in the
range ``[0, num_parts)``.
:arg self_mesh_groups: An array of `~meshmode.mesh.MeshElementGroup` instances
:arg self_mesh_groups: A list of `~meshmode.mesh.MeshElementGroup` instances
representing the partitioned mesh groups.
:arg self_mesh_group_elem_base: An array containing the starting part-wide
:arg self_mesh_group_elem_base: A list containing the starting part-wide
element index for each group in *self_mesh_groups*.
:returns: A list of lists of `~meshmode.mesh.BoundaryAdjacencyGroup` instances
Expand Down Expand Up @@ -411,11 +436,14 @@ def _create_boundary_groups(mesh, global_elem_to_part_elem, self_part_index,
return bdry_adj_groups


def _get_mesh_part(mesh, part_id_to_elements, self_part_id):
def _get_mesh_part(
mesh: Mesh,
part_id_to_elements: Dict[PartID, np.ndarray],
self_part_id: PartID) -> Mesh:
"""
:arg mesh: A :class:`~meshmode.mesh.Mesh` to be partitioned.
:arg part_id_to_elements: A :class:`dict` mapping part identifiers to
sets of elements.
:arg part_id_to_elements: A :class:`dict` mapping a part identifier to
a sorted :class:`numpy.ndarray` of elements.
:arg self_part_id: The part identifier of the mesh to return.
:returns: A :class:`~meshmode.mesh.Mesh` containing a part of *mesh*.
Expand All @@ -429,9 +457,7 @@ def _get_mesh_part(mesh, part_id_to_elements, self_part_id):

part_id_to_part_index = {
part_id: part_index
for part_id, part_index in zip(
part_id_to_elements.keys(),
range(len(part_id_to_elements)))}
for part_index, part_id in enumerate(part_id_to_elements.keys())}

global_elem_to_part_elem = _compute_global_elem_to_part_elem(
mesh.nelements, part_id_to_elements, part_id_to_part_index,
Expand Down Expand Up @@ -477,21 +503,20 @@ def _get_mesh_part(mesh, part_id_to_elements, self_part_id):
+ boundary_adj_groups[igrp]
for igrp in range(len(self_mesh_groups))]

from meshmode.mesh import Mesh
self_mesh = Mesh(
return Mesh(
self_vertices,
self_mesh_groups,
facial_adjacency_groups=self_facial_adj_groups,
is_conforming=mesh.is_conforming)

return self_mesh


def partition_mesh(mesh, part_id_to_elements):
def partition_mesh(
mesh: Mesh,
part_id_to_elements: Dict[PartID, np.ndarray]) -> "Dict[PartID, Mesh]":
"""
:arg mesh: A :class:`~meshmode.mesh.Mesh` to be partitioned.
:arg part_id_to_elements: A :class:`dict` mapping part identifiers to sets of
elements.
:arg part_id_to_elements: A :class:`dict` mapping a part identifier to
a sorted :class:`numpy.ndarray` of elements.
:returns: A :class:`dict` mapping part identifiers to instances of
:class:`~meshmode.mesh.Mesh` that represent the corresponding part of
Expand Down Expand Up @@ -693,8 +718,6 @@ def perform_flips(mesh, flip_flags, skip_tests=False):

flip_flags = flip_flags.astype(bool)

from meshmode.mesh import Mesh

new_groups = []
for base_element_nr, grp in zip(mesh.base_element_nrs, mesh.groups):
grp_flip_flags = flip_flags[base_element_nr:base_element_nr + grp.nelements]
Expand Down Expand Up @@ -832,7 +855,6 @@ def merge_disjoint_meshes(meshes, skip_tests=False, single_group=False):

# }}}

from meshmode.mesh import Mesh
return Mesh(vertices, new_groups, skip_tests=skip_tests,
nodal_adjacency=nodal_adjacency,
facial_adjacency_groups=facial_adjacency_groups,
Expand Down Expand Up @@ -890,7 +912,6 @@ def split_mesh_groups(mesh, element_flags, return_subgroup_mapping=False):
element_nr_base=None, node_nr_base=None,
))

from meshmode.mesh import Mesh
mesh = Mesh(
vertices=mesh.vertices,
groups=new_groups,
Expand Down Expand Up @@ -1180,8 +1201,6 @@ def glue_mesh_boundaries(mesh, bdry_pair_mappings_and_tols, *, use_tree=None):
_match_boundary_faces(mesh, mapping, tol, use_tree=use_tree)
for mapping, tol in bdry_pair_mappings_and_tols]

from meshmode.mesh import InteriorAdjacencyGroup, BoundaryAdjacencyGroup

facial_adjacency_groups = []

for igrp, old_fagrp_list in enumerate(mesh.facial_adjacency_groups):
Expand Down

0 comments on commit 1dbee5c

Please sign in to comment.