Skip to content

Commit

Permalink
Change for MixedMesh implementation (#718)
Browse files Browse the repository at this point in the history
* fix Parloop kernel arg ordering

* change for mixed mesh implementation
  • Loading branch information
ksagiyam authored Apr 3, 2024
1 parent f424fb5 commit 594e87b
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 13 deletions.
4 changes: 2 additions & 2 deletions pyop2/parloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _kernel_args_(self):
@property
def map_kernel_args(self):
rmap, cmap = self.maps
return tuple(itertools.chain(*itertools.product(rmap._kernel_args_, cmap._kernel_args_)))
return tuple(itertools.chain(rmap._kernel_args_, cmap._kernel_args_))


@dataclass
Expand All @@ -143,7 +143,7 @@ def _kernel_args_(self):
@property
def map_kernel_args(self):
rmap, cmap = self.maps
return tuple(itertools.chain(*itertools.product(rmap._kernel_args_, cmap._kernel_args_)))
return tuple(itertools.chain(rmap._kernel_args_, cmap._kernel_args_))


@dataclass
Expand Down
4 changes: 3 additions & 1 deletion pyop2/sparsity.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def fill_with_zeros(PETSc.Mat mat not None, dims, maps, iteration_regions, set_d
PetscInt[:, ::1] rmap, cmap, tempmap
PetscInt **rcomposedmaps = NULL
PetscInt **ccomposedmaps = NULL
PetscInt nrcomposedmaps = 0, nccomposedmaps = 0, rset_entry, cset_entry
PetscInt nrcomposedmaps, nccomposedmaps, rset_entry, cset_entry
PetscInt *rvals
PetscInt *cvals
PetscInt *roffset
Expand Down Expand Up @@ -235,6 +235,7 @@ def fill_with_zeros(PETSc.Mat mat not None, dims, maps, iteration_regions, set_d
else:
rflags.append(set_writeable(pair[0])) # Memoryviews require writeable buffers
rmap = pair[0].values_with_halo # Map values
nrcomposedmaps = 0
if isinstance(pair[1], op2.ComposedMap):
m = pair[1].flattened_maps[0]
cflags.append(set_writeable(m))
Expand All @@ -243,6 +244,7 @@ def fill_with_zeros(PETSc.Mat mat not None, dims, maps, iteration_regions, set_d
else:
cflags.append(set_writeable(pair[1]))
cmap = pair[1].values_with_halo
nccomposedmaps = 0
# Handle ComposedMaps
CHKERR(PetscMalloc2(nrcomposedmaps, &rcomposedmaps, nccomposedmaps, &ccomposedmaps))
for i in range(nrcomposedmaps):
Expand Down
16 changes: 11 additions & 5 deletions pyop2/types/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,6 @@ def __init__(self, maps):
if self._initialized:
return
self._maps = maps
if not all(m is None or m.iterset == self.iterset for m in self._maps):
raise ex.MapTypeError("All maps in a MixedMap need to share the same iterset")
# TODO: Think about different communicators on maps (c.f. MixedSet)
# TODO: What if all maps are None?
comms = tuple(m.comm for m in self._maps if m is not None)
Expand Down Expand Up @@ -344,7 +342,11 @@ def split(self):
@utils.cached_property
def iterset(self):
""":class:`MixedSet` mapped from."""
return functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.iterset, self._maps))
s, = set(m.iterset for m in self._maps)
if len(s) == 1:
return functools.reduce(lambda a, b: a or b, map(lambda s: s if s is None else s.iterset, self._maps))
else:
raise RuntimeError("Found multiple itersets.")

@utils.cached_property
def toset(self):
Expand All @@ -356,7 +358,11 @@ def toset(self):
def arity(self):
"""Arity of the mapping: total number of toset elements mapped to per
iterset element."""
return sum(m.arity for m in self._maps)
s, = set(m.iterset for m in self._maps)
if len(s) == 1:
return sum(m.arity for m in self._maps)
else:
raise RuntimeError("Found multiple itersets.")

@utils.cached_property
def arities(self):
Expand Down Expand Up @@ -402,7 +408,7 @@ def offset(self):
@utils.cached_property
def offset_quotient(self):
"""Offsets quotient."""
raise NotImplementedError("offset_quotient not implemented for MixedMap")
return tuple(0 if m is None else m.offset_quotient for m in self._maps)

def __iter__(self):
r"""Yield all :class:`Map`\s when iterated over."""
Expand Down
5 changes: 5 additions & 0 deletions pyop2/types/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ def __init__(self, size, name=None, halo=None, comm=None):
# A cache of objects built on top of this set
self._cache = {}

@property
def indices(self):
"""Returns iterator."""
return range(self.total_size)

@utils.cached_property
def core_size(self):
"""Core set size. Owned elements not touching halo elements."""
Expand Down
5 changes: 0 additions & 5 deletions test/unit/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1446,11 +1446,6 @@ def test_mixed_map_split(self, maps):
assert mmap.split[i] == m
assert mmap.split[:-1] == tuple(mmap)[:-1]

def test_mixed_map_nonunique_itset(self, m_iterset_toset, m_set_toset):
"Map toset should be Set."
with pytest.raises(exceptions.MapTypeError):
op2.MixedMap((m_iterset_toset, m_set_toset))

def test_mixed_map_iterset(self, mmap):
"MixedMap iterset should return the common iterset of all Maps."
for m in mmap:
Expand Down

0 comments on commit 594e87b

Please sign in to comment.