Skip to content

Commit

Permalink
MultiFab Fixture Cleanup via FabArray::clear (#214)
Browse files Browse the repository at this point in the history
* MultiFab Fixture Cleanup via `FabArray::clear`

Using a context manager and calling clear ensures that we will
not hold device memory anymore once we hit `amrex::Finalize()`,
even in the situation where an exception is raised in a test.
This avoids segfaults for failing tests.

* `test_mfab_dtoh_copy`: Clear MFabs

Clear out memory safely on runtime errors.

* Update Stub Files

---------

Co-authored-by: ax3l <[email protected]>
  • Loading branch information
ax3l and ax3l authored Nov 1, 2023
1 parent 056d332 commit 3c73a42
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 79 deletions.
4 changes: 4 additions & 0 deletions src/Base/MultiFab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ void init_MultiFab(py::module &m)
;

py_FabArray_FArrayBox
// define
.def("clear", &FabArray<FArrayBox>::clear)
.def("ok", &FabArray<FArrayBox>::ok)

//.def("array", py::overload_cast< const MFIter& >(&FabArray<FArrayBox>::array))
//.def("const_array", &FabArray<FArrayBox>::const_array)
.def("array", [](FabArray<FArrayBox> & fa, MFIter const & mfi)
Expand Down
2 changes: 2 additions & 0 deletions src/amrex/space1d/amrex_1d_pybind/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3755,6 +3755,7 @@ class FabArray_FArrayBox(FabArrayBase):
arg6: IntVect,
) -> None: ...
def array(self, arg0: MFIter) -> Array4_double: ...
def clear(self) -> None: ...
def const_array(self, arg0: MFIter) -> Array4_double_const: ...
@typing.overload
def fill_boundary(self, cross: bool = False) -> None: ...
Expand All @@ -3779,6 +3780,7 @@ class FabArray_FArrayBox(FabArrayBase):
period: Periodicity,
cross: bool = False,
) -> None: ...
def ok(self) -> bool: ...
def override_sync(self, arg0: Periodicity) -> None: ...
def sum(self, arg0: int, arg1: IntVect, arg2: bool) -> float: ...
@typing.overload
Expand Down
2 changes: 2 additions & 0 deletions src/amrex/space2d/amrex_2d_pybind/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3755,6 +3755,7 @@ class FabArray_FArrayBox(FabArrayBase):
arg6: IntVect,
) -> None: ...
def array(self, arg0: MFIter) -> Array4_double: ...
def clear(self) -> None: ...
def const_array(self, arg0: MFIter) -> Array4_double_const: ...
@typing.overload
def fill_boundary(self, cross: bool = False) -> None: ...
Expand All @@ -3779,6 +3780,7 @@ class FabArray_FArrayBox(FabArrayBase):
period: Periodicity,
cross: bool = False,
) -> None: ...
def ok(self) -> bool: ...
def override_sync(self, arg0: Periodicity) -> None: ...
def sum(self, arg0: int, arg1: IntVect, arg2: bool) -> float: ...
@typing.overload
Expand Down
2 changes: 2 additions & 0 deletions src/amrex/space3d/amrex_3d_pybind/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3755,6 +3755,7 @@ class FabArray_FArrayBox(FabArrayBase):
arg6: IntVect,
) -> None: ...
def array(self, arg0: MFIter) -> Array4_double: ...
def clear(self) -> None: ...
def const_array(self, arg0: MFIter) -> Array4_double_const: ...
@typing.overload
def fill_boundary(self, cross: bool = False) -> None: ...
Expand All @@ -3779,6 +3780,7 @@ class FabArray_FArrayBox(FabArrayBase):
period: Periodicity,
cross: bool = False,
) -> None: ...
def ok(self) -> bool: ...
def override_sync(self, arg0: Periodicity) -> None: ...
def sum(self, arg0: int, arg1: IntVect, arg2: bool) -> float: ...
@typing.overload
Expand Down
58 changes: 35 additions & 23 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,47 +85,59 @@ def distmap(boxarr):


@pytest.fixture(scope="function", params=list(itertools.product([1, 3], [0, 1])))
def mfab(boxarr, distmap, request):
    """MultiFab fixture, parametrized over (num_components, num_ghost).

    The MultiFab object itself is not cached between tests because we want
    to avoid holding it across amr.initialize/finalize calls of various tests:
    https://github.com/pytest-dev/pytest/discussions/10387
    https://github.com/pytest-dev/pytest/issues/5642#issuecomment-1279612764
    """
    num_components, num_ghost = request.param
    mf = amr.MultiFab(boxarr, distmap, num_components, num_ghost)
    mf.set_val(0.0, 0, num_components)
    try:
        yield mf
    finally:
        # Release the MultiFab's memory *before* AMReX finalizes, even when
        # the test body raised — avoids segfaults for failing tests.
        mf.clear()
        del mf


@pytest.fixture(scope="function", params=list(itertools.product([1, 3], [0, 1])))
def mfab_device(boxarr, distmap, request):
    """Device-only MultiFab fixture, parametrized over (num_components, num_ghost).

    The MultiFab object itself is not cached between tests because we want
    to avoid holding it across amr.initialize/finalize calls of various tests:
    https://github.com/pytest-dev/pytest/discussions/10387
    https://github.com/pytest-dev/pytest/issues/5642#issuecomment-1279612764
    """
    # NOTE: @pytest.mark.skipif has no effect on fixture functions (pytest
    # ignores marks applied to fixtures), so skip explicitly inside the body.
    if amr.Config.gpu_backend != "CUDA":
        pytest.skip("Requires AMReX_GPU_BACKEND=CUDA")

    num_components, num_ghost = request.param
    mf = amr.MultiFab(
        boxarr,
        distmap,
        num_components,
        num_ghost,
        amr.MFInfo().set_arena(amr.The_Device_Arena()),
    )
    mf.set_val(0.0, 0, num_components)
    try:
        yield mf
    finally:
        # Release device memory *before* AMReX finalizes, even when the
        # test body raised — avoids segfaults for failing tests.
        mf.clear()
        del mf
114 changes: 58 additions & 56 deletions tests/test_multifab.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import amrex.space3d as amr


def test_mfab_loop(make_mfab):
mfab = make_mfab()
def test_mfab_loop(mfab):
ngv = mfab.nGrowVect
print(f"\n mfab={mfab}, mfab.nGrowVect={ngv}")

Expand Down Expand Up @@ -78,8 +77,7 @@ def test_mfab_loop(make_mfab):
# TODO


def test_mfab_simple(make_mfab):
mfab = make_mfab()
def test_mfab_simple(mfab):
assert mfab.is_all_cell_centered
# assert(all(not mfab.is_nodal(i) for i in [-1, 0, 1, 2])) # -1??
assert all(not mfab.is_nodal(i) for i in [0, 1, 2])
Expand Down Expand Up @@ -144,8 +142,7 @@ def test_mfab_ops(boxarr, distmap, nghost):
np.testing.assert_allclose(dst.max(0), 150.0)


def test_mfab_mfiter(make_mfab):
mfab = make_mfab()
def test_mfab_mfiter(mfab):
assert iter(mfab).is_valid
assert iter(mfab).length == 8

Expand All @@ -159,8 +156,7 @@ def test_mfab_mfiter(make_mfab):
@pytest.mark.skipif(
amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
)
def test_mfab_ops_cuda_numba(make_mfab_device):
mfab_device = make_mfab_device()
def test_mfab_ops_cuda_numba(mfab_device):
# https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html
from numba import cuda

Expand Down Expand Up @@ -195,8 +191,7 @@ def set_to_three(array):
@pytest.mark.skipif(
amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
)
def test_mfab_ops_cuda_cupy(make_mfab_device):
mfab_device = make_mfab_device()
def test_mfab_ops_cuda_cupy(mfab_device):
# https://docs.cupy.dev/en/stable/user_guide/interoperability.html
import cupy as cp
import cupyx.profiler
Expand Down Expand Up @@ -285,8 +280,7 @@ def set_to_seven(x):
@pytest.mark.skipif(
amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
)
def test_mfab_ops_cuda_pytorch(make_mfab_device):
mfab_device = make_mfab_device()
def test_mfab_ops_cuda_pytorch(mfab_device):
# https://docs.cupy.dev/en/stable/user_guide/interoperability.html#pytorch
import torch

Expand All @@ -305,8 +299,8 @@ def test_mfab_ops_cuda_pytorch(make_mfab_device):
@pytest.mark.skipif(
amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
)
def test_mfab_ops_cuda_cuml(make_mfab_device):
mfab_device = make_mfab_device() # noqa
def test_mfab_ops_cuda_cuml(mfab_device):
pass
# https://github.com/rapidsai/cuml
# https://github.com/rapidsai/cudf
# maybe better for particles as a dataframe test
Expand All @@ -322,47 +316,55 @@ def test_mfab_ops_cuda_cuml(make_mfab_device):
@pytest.mark.skipif(
amr.Config.gpu_backend != "CUDA", reason="Requires AMReX_GPU_BACKEND=CUDA"
)
def test_mfab_dtoh_copy(make_mfab_device):
mfab_device = make_mfab_device()

mfab_host = amr.MultiFab(
mfab_device.box_array(),
mfab_device.dm(),
mfab_device.n_comp(),
mfab_device.n_grow_vect(),
amr.MFInfo().set_arena(amr.The_Pinned_Arena()),
)
mfab_host.set_val(42.0)

amr.dtoh_memcpy(mfab_host, mfab_device)

# assert all are 0.0 on host
host_min = mfab_host.min(0)
host_max = mfab_host.max(0)
assert host_min == host_max
assert host_max == 0.0

dev_val = 11.0
mfab_host.set_val(dev_val)
amr.dtoh_memcpy(mfab_device, mfab_host)

# assert all are 11.0 on device
for n in range(mfab_device.n_comp()):
assert mfab_device.min(comp=n) == dev_val
assert mfab_device.max(comp=n) == dev_val

# numpy bindings (w/ copy)
local_boxes_host = mfab_device.to_numpy(copy=True)
assert max([np.max(box) for box in local_boxes_host]) == dev_val

# numpy bindings (w/ copy)
for mfi in mfab_device:
marr = mfab_device.array(mfi).to_numpy(copy=True)
assert np.min(marr) >= dev_val
assert np.max(marr) <= dev_val
def test_mfab_dtoh_copy(mfab_device):
    """Round-trip device <-> host copies through a pinned-memory MultiFab."""
    mfab_host = amr.MultiFab(
        mfab_device.box_array(),
        mfab_device.dm(),
        mfab_device.n_comp(),
        mfab_device.n_grow_vect(),
        amr.MFInfo().set_arena(amr.The_Pinned_Arena()),
    )
    try:
        mfab_host.set_val(42.0)

        # device -> host: the fixture initialized the device MultiFab to 0.0,
        # so the 42.0 sentinel on the host must be overwritten with zeros
        amr.dtoh_memcpy(mfab_host, mfab_device)
        assert mfab_host.min(0) == mfab_host.max(0) == 0.0

        # host -> device: every device value becomes dev_val
        dev_val = 11.0
        mfab_host.set_val(dev_val)
        amr.htod_memcpy(mfab_device, mfab_host)
        for n in range(mfab_device.n_comp()):
            assert mfab_device.min(comp=n) == dev_val
            assert mfab_device.max(comp=n) == dev_val

        # whole-MultiFab numpy bindings (w/ copy)
        local_boxes_host = mfab_device.to_numpy(copy=True)
        assert max(np.max(box) for box in local_boxes_host) == dev_val
        del local_boxes_host

        # per-box numpy bindings (w/ copy)
        for mfi in mfab_device:
            marr = mfab_device.array(mfi).to_numpy(copy=True)
            assert np.min(marr) >= dev_val
            assert np.max(marr) <= dev_val

        # cupy bindings (w/o copy)
        import cupy as cp

        local_boxes_device = mfab_device.to_cupy()
        assert max(cp.max(box) for box in local_boxes_device) == dev_val
    finally:
        # Clear the pinned-memory MultiFab even when an assertion failed,
        # so no memory is held when AMReX finalizes.
        mfab_host.clear()
        del mfab_host

0 comments on commit 3c73a42

Please sign in to comment.