Skip to content

Commit

Permalink
added error checking on cylindrical mesh (#2977)
Browse files Browse the repository at this point in the history
Co-authored-by: Paul Romano <[email protected]>
  • Loading branch information
hsameer481 and paulromano authored Jun 10, 2024
1 parent 12a278b commit e971bd1
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 4 deletions.
21 changes: 18 additions & 3 deletions openmc/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,8 @@ def r_grid(self):
@r_grid.setter
def r_grid(self, grid):
cv.check_type('mesh r_grid', grid, Iterable, Real)
cv.check_length('mesh r_grid', grid, 2)
cv.check_increasing('mesh r_grid', grid)
self._r_grid = np.asarray(grid, dtype=float)

@property
Expand All @@ -1397,7 +1399,12 @@ def phi_grid(self):
@phi_grid.setter
def phi_grid(self, grid):
cv.check_type('mesh phi_grid', grid, Iterable, Real)
self._phi_grid = np.asarray(grid, dtype=float)
cv.check_length('mesh phi_grid', grid, 2)
cv.check_increasing('mesh phi_grid', grid)
grid = np.asarray(grid, dtype=float)
if np.any((grid < 0.0) | (grid > 2*pi)):
raise ValueError("phi_grid values must be in [0, 2π].")
self._phi_grid = grid

@property
def z_grid(self):
Expand All @@ -1406,6 +1413,8 @@ def z_grid(self):
@z_grid.setter
def z_grid(self, grid):
cv.check_type('mesh z_grid', grid, Iterable, Real)
cv.check_length('mesh z_grid', grid, 2)
cv.check_increasing('mesh z_grid', grid)
self._z_grid = np.asarray(grid, dtype=float)

@property
Expand Down Expand Up @@ -1840,7 +1849,10 @@ def theta_grid(self, grid):
cv.check_type('mesh theta_grid', grid, Iterable, Real)
cv.check_length('mesh theta_grid', grid, 2)
cv.check_increasing('mesh theta_grid', grid)
self._theta_grid = np.asarray(grid, dtype=float)
grid = np.asarray(grid, dtype=float)
if np.any((grid < 0.0) | (grid > pi)):
raise ValueError("theta_grid values must be in [0, π].")
self._theta_grid = grid

@property
def phi_grid(self):
Expand All @@ -1851,7 +1863,10 @@ def phi_grid(self, grid):
cv.check_type('mesh phi_grid', grid, Iterable, Real)
cv.check_length('mesh phi_grid', grid, 2)
cv.check_increasing('mesh phi_grid', grid)
self._phi_grid = np.asarray(grid, dtype=float)
grid = np.asarray(grid, dtype=float)
if np.any((grid < 0.0) | (grid > 2*pi)):
raise ValueError("phi_grid values must be in [0, 2π].")
self._phi_grid = grid

@property
def _grids(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/mesh_to_vtk/test_vtk_dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def test_write_data_to_vtk_round_trip(run_in_tmpdir):

smesh = openmc.SphericalMesh(
r_grid=(0.0, 1.0, 2.0),
theta_grid=(0.0, 2.0, 4.0, 5.0),
theta_grid=(0.0, 0.5, 1.0, 2.0),
phi_grid=(0.0, 3.0, 6.0),
)
rmesh = openmc.RegularMesh()
Expand Down
36 changes: 36 additions & 0 deletions tests/unit_tests/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,42 @@ def test_CylindricalMesh_initiation():
openmc.SphericalMesh(('🧇', '🥞'))


def test_invalid_cylindrical_mesh_errors():
# Test invalid r_grid values
with pytest.raises(ValueError):
openmc.CylindricalMesh(r_grid=[5, 1], phi_grid=[0, pi], z_grid=[0, 10])

with pytest.raises(ValueError):
openmc.CylindricalMesh(r_grid=[1, 2, 4, 3], phi_grid=[0, pi], z_grid=[0, 10])

with pytest.raises(ValueError):
openmc.CylindricalMesh(r_grid=[1], phi_grid=[0, pi], z_grid=[0, 10])

# Test invalid phi_grid values
with pytest.raises(ValueError):
openmc.CylindricalMesh(r_grid=[0, 1, 2], phi_grid=[-1, 3], z_grid=[0, 10])

with pytest.raises(ValueError):
openmc.CylindricalMesh(
r_grid=[0, 1, 2],
phi_grid=[0, 2*pi + 0.1],
z_grid=[0, 10]
)

with pytest.raises(ValueError):
openmc.CylindricalMesh(r_grid=[0, 1, 2], phi_grid=[pi], z_grid=[0, 10])

# Test invalid z_grid values
with pytest.raises(ValueError):
openmc.CylindricalMesh(r_grid=[0, 1, 2], phi_grid=[0, pi], z_grid=[5])

with pytest.raises(ValueError):
openmc.CylindricalMesh(r_grid=[0, 1, 2], phi_grid=[0, pi], z_grid=[5, 1])

with pytest.raises(ValueError):
openmc.CylindricalMesh(r_grid=[1, 2, 4, 3], phi_grid=[0, pi], z_grid=[0, 10, 5])


def test_centroids():
# regular mesh
mesh = openmc.RegularMesh()
Expand Down

0 comments on commit e971bd1

Please sign in to comment.